JiantaoLin
commited on
Commit
·
02a9751
1
Parent(s):
e33401c
new
Browse files- app.py +431 -323
- demo.py → app_demo.py +98 -38
- app_flux.py +0 -141
- image_to_mesh.py +0 -437
- image_to_mesh_new.py +0 -436
- live_preview_helpers.py +0 -167
- models/llm/__pycache__/llm.cpython-310.pyc +0 -0
- models/llm/llm.py +97 -0
- pipeline/custom_pipelines/__init__.py +3 -0
- pipeline/custom_pipelines/pipeline_flux_controlnet_image_to_image.py +1004 -0
- pipeline/custom_pipelines/pipeline_flux_img2img.py +862 -0
- pipeline/custom_pipelines/pipeline_flux_prior_redux.py +500 -0
- pipeline/example_text_to_3d.py +7 -0
- pipeline/kiss3d_wrapper.py +416 -75
- pipeline/pipeline_config/default.yaml +16 -8
- pipeline/run_hpc.sh +1 -1
- run_hpc.sh → pipeline/run_hpc_text_to_3d.sh +3 -4
- pipeline/utils.py +34 -6
- run.sh +0 -2
- text_to_mesh.py +0 -232
- text_to_mesh_new.py +0 -244
- upload_huggingface.py +0 -57
- video_render.py +123 -0
app.py
CHANGED
@@ -1,9 +1,35 @@
|
|
1 |
-
import gradio as gr
|
2 |
import os
|
|
|
3 |
import subprocess
|
4 |
-
import shlex
|
5 |
import spaces
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
access_token = os.getenv("HUGGINGFACE_TOKEN")
|
8 |
subprocess.run(
|
9 |
shlex.split(
|
@@ -22,6 +48,8 @@ subprocess.run(
|
|
22 |
"pip install ./extension/renderutils_plugin-0.1.0-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
|
23 |
)
|
24 |
)
|
|
|
|
|
25 |
def install_cuda_toolkit():
|
26 |
# CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
|
27 |
# CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
|
@@ -41,6 +69,7 @@ def install_cuda_toolkit():
|
|
41 |
os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
|
42 |
print("==> finfish install")
|
43 |
install_cuda_toolkit()
|
|
|
44 |
@spaces.GPU
|
45 |
def check_gpu():
|
46 |
os.environ['CUDA_HOME'] = '/usr/local/cuda-12.1'
|
@@ -48,338 +77,417 @@ def check_gpu():
|
|
48 |
# os.environ['LD_LIBRARY_PATH'] += ':/usr/local/cuda-12.1/lib64'
|
49 |
os.environ['LD_LIBRARY_PATH'] = "/usr/local/cuda-12.1/lib64:" + os.environ.get('LD_LIBRARY_PATH', '')
|
50 |
subprocess.run(['nvidia-smi']) # 测试 CUDA 是否可用
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
print(f"torch.cuda.is_available:{torch.cuda.is_available()}")
|
52 |
check_gpu()
|
53 |
|
54 |
-
from
|
55 |
-
from einops import rearrange
|
56 |
-
from diffusers import FluxPipeline
|
57 |
-
from models.lrm.utils.camera_util import get_flux_input_cameras
|
58 |
-
from models.lrm.utils.infer_util import save_video
|
59 |
-
from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
|
60 |
-
from models.lrm.utils.render_utils import rotate_x, rotate_y
|
61 |
-
from models.lrm.utils.train_util import instantiate_from_config
|
62 |
-
from models.ISOMER.reconstruction_func import reconstruction
|
63 |
-
from models.ISOMER.projection_func import projection
|
64 |
-
import os
|
65 |
-
from einops import rearrange
|
66 |
-
from omegaconf import OmegaConf
|
67 |
-
import torch
|
68 |
-
import numpy as np
|
69 |
-
import trimesh
|
70 |
-
import torchvision
|
71 |
-
import torch.nn.functional as F
|
72 |
-
from PIL import Image
|
73 |
-
from torchvision import transforms
|
74 |
-
from torchvision.transforms import v2
|
75 |
-
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
|
76 |
-
from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
|
77 |
-
from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
|
78 |
-
from diffusers import FluxPipeline
|
79 |
-
from pytorch_lightning import seed_everything
|
80 |
-
import os
|
81 |
-
from huggingface_hub import hf_hub_download
|
82 |
-
|
83 |
-
|
84 |
-
from utils.tool import NormalTransfer, get_background, get_render_cameras_video, load_mipmap, render_frames
|
85 |
-
|
86 |
-
device_0 = "cuda"
|
87 |
-
device_1 = "cuda"
|
88 |
-
resolution = 512
|
89 |
-
save_dir = "./outputs"
|
90 |
-
normal_transfer = NormalTransfer()
|
91 |
-
isomer_azimuths = torch.from_numpy(np.array([0, 90, 180, 270])).float().to(device_1)
|
92 |
-
isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).float().to(device_1)
|
93 |
-
isomer_radius = 4.5
|
94 |
-
isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device_1)
|
95 |
-
isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device_1)
|
96 |
-
|
97 |
-
# model initialization and loading
|
98 |
-
# flux
|
99 |
-
# # taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to(device_0)
|
100 |
-
# # good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16, token=access_token).to(device_0)
|
101 |
-
# flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, token=access_token).to(device=device_0, dtype=torch.bfloat16)
|
102 |
-
# # flux_pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, vae=taef1, token=access_token).to(device_0)
|
103 |
-
# flux_lora_ckpt_path = hf_hub_download(repo_id="LTT/xxx-ckpt", filename="rgb_normal_large.safetensors", repo_type="model", token=access_token)
|
104 |
-
# flux_pipe.load_lora_weights(flux_lora_ckpt_path)
|
105 |
-
# flux_pipe.to(device=device_0, dtype=torch.bfloat16)
|
106 |
-
# torch.cuda.empty_cache()
|
107 |
-
# flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(flux_pipe)
|
108 |
-
|
109 |
-
|
110 |
-
# lrm
|
111 |
-
config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
|
112 |
-
model_config = config.model_config
|
113 |
-
infer_config = config.infer_config
|
114 |
-
model = instantiate_from_config(model_config)
|
115 |
-
model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
|
116 |
-
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
117 |
-
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
|
118 |
-
model.load_state_dict(state_dict, strict=True)
|
119 |
-
model = model.to(device_1)
|
120 |
-
torch.cuda.empty_cache()
|
121 |
-
@spaces.GPU
|
122 |
-
def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
|
123 |
-
images = image.unsqueeze(0).to(device_1)
|
124 |
-
images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
|
125 |
-
# breakpoint()
|
126 |
-
with torch.no_grad():
|
127 |
-
# get triplane
|
128 |
-
planes = model.forward_planes(images, input_cameras)
|
129 |
-
|
130 |
-
mesh_path_idx = os.path.join(save_path, f'{name}.obj')
|
131 |
-
|
132 |
-
mesh_out = model.extract_mesh(
|
133 |
-
planes,
|
134 |
-
use_texture_map=export_texmap,
|
135 |
-
**infer_config,
|
136 |
-
)
|
137 |
-
if export_texmap:
|
138 |
-
vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
|
139 |
-
save_obj_with_mtl(
|
140 |
-
vertices.data.cpu().numpy(),
|
141 |
-
uvs.data.cpu().numpy(),
|
142 |
-
faces.data.cpu().numpy(),
|
143 |
-
mesh_tex_idx.data.cpu().numpy(),
|
144 |
-
tex_map.permute(1, 2, 0).data.cpu().numpy(),
|
145 |
-
mesh_path_idx,
|
146 |
-
)
|
147 |
-
else:
|
148 |
-
vertices, faces, vertex_colors = mesh_out
|
149 |
-
save_obj(vertices, faces, vertex_colors, mesh_path_idx)
|
150 |
-
print(f"Mesh saved to {mesh_path_idx}")
|
151 |
-
|
152 |
-
render_size = 512
|
153 |
-
if if_save_video:
|
154 |
-
video_path_idx = os.path.join(save_path, f'{name}.mp4')
|
155 |
-
render_size = infer_config.render_resolution
|
156 |
-
ENV = load_mipmap("models/lrm/env_mipmap/6")
|
157 |
-
materials = (0.0,0.9)
|
158 |
-
|
159 |
-
all_mv, all_mvp, all_campos = get_render_cameras_video(
|
160 |
-
batch_size=1,
|
161 |
-
M=24,
|
162 |
-
radius=4.5,
|
163 |
-
elevation=(90, 60.0),
|
164 |
-
is_flexicubes=True,
|
165 |
-
fov=30
|
166 |
-
)
|
167 |
-
|
168 |
-
frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
|
169 |
-
model,
|
170 |
-
planes,
|
171 |
-
render_cameras=all_mvp,
|
172 |
-
camera_pos=all_campos,
|
173 |
-
env=ENV,
|
174 |
-
materials=materials,
|
175 |
-
render_size=render_size,
|
176 |
-
chunk_size=20,
|
177 |
-
is_flexicubes=True,
|
178 |
-
)
|
179 |
-
normals = (torch.nn.functional.normalize(normals) + 1) / 2
|
180 |
-
normals = normals * alphas + (1-alphas)
|
181 |
-
all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
|
182 |
-
|
183 |
-
save_video(
|
184 |
-
all_frames,
|
185 |
-
video_path_idx,
|
186 |
-
fps=30,
|
187 |
-
)
|
188 |
-
print(f"Video saved to {video_path_idx}")
|
189 |
|
190 |
-
|
191 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
|
193 |
-
def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg):
|
194 |
-
if local_normal_images.min() >= 0:
|
195 |
-
local_normal = local_normal_images.float() * 2 - 1
|
196 |
-
else:
|
197 |
-
local_normal = local_normal_images.float()
|
198 |
-
global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
|
199 |
-
global_normal[...,0] *= -1
|
200 |
-
global_normal = (global_normal + 1) / 2
|
201 |
-
global_normal = global_normal.permute(0, 3, 1, 2)
|
202 |
-
return global_normal
|
203 |
-
|
204 |
-
# 生成多视图图像
|
205 |
-
@spaces.GPU(duration=120)
|
206 |
-
def generate_multi_view_images(prompt, seed):
|
207 |
-
# torch.cuda.empty_cache()
|
208 |
-
# generator = torch.manual_seed(seed)
|
209 |
-
generator = torch.Generator().manual_seed(seed)
|
210 |
-
with torch.no_grad():
|
211 |
-
img = flux_pipe(
|
212 |
-
prompt=prompt,
|
213 |
-
num_inference_steps=5,
|
214 |
-
guidance_scale=3.5,
|
215 |
-
num_images_per_prompt=1,
|
216 |
-
width=resolution * 2,
|
217 |
-
height=resolution * 1,
|
218 |
-
output_type='np',
|
219 |
-
generator=generator,
|
220 |
-
).images
|
221 |
-
# for img in flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images(
|
222 |
-
# prompt=prompt,
|
223 |
-
# guidance_scale=3.5,
|
224 |
-
# num_inference_steps=4,
|
225 |
-
# width=resolution * 4,
|
226 |
-
# height=resolution * 2,
|
227 |
-
# generator=generator,
|
228 |
-
# output_type="np",
|
229 |
-
# good_vae=good_vae,
|
230 |
-
# ):
|
231 |
-
# pass
|
232 |
-
# 返回最终的图像和种子(通过外部调用处理)
|
233 |
-
return img
|
234 |
-
|
235 |
-
# 重建 3D 模型
|
236 |
@spaces.GPU
|
237 |
-
def
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
# multi_view_mask: B,H,W
|
268 |
-
# rgb_multi_view: B,H,W,3
|
269 |
|
270 |
-
|
271 |
-
normal_pils=global_normal,
|
272 |
-
masks=multi_view_mask,
|
273 |
-
weights=isomer_geo_weights,
|
274 |
-
fov=30,
|
275 |
-
radius=isomer_radius,
|
276 |
-
camera_angles_azi=isomer_azimuths,
|
277 |
-
camera_angles_ele=isomer_elevations,
|
278 |
-
expansion_weight_stage1=0.1,
|
279 |
-
init_type="file",
|
280 |
-
init_verts=vertices,
|
281 |
-
init_faces=faces,
|
282 |
-
stage1_steps=0,
|
283 |
-
stage2_steps=50,
|
284 |
-
start_edge_len_stage1=0.1,
|
285 |
-
end_edge_len_stage1=0.02,
|
286 |
-
start_edge_len_stage2=0.02,
|
287 |
-
end_edge_len_stage2=0.005,
|
288 |
-
)
|
289 |
|
|
|
|
|
290 |
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
)
|
302 |
|
303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
304 |
|
305 |
-
# Gradio 接口函数
|
306 |
@spaces.GPU
|
307 |
-
def
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
#
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
382 |
|
383 |
-
|
384 |
-
|
385 |
-
demo.launch()
|
|
|
|
|
1 |
import os
|
2 |
+
import gradio as gr
|
3 |
import subprocess
|
|
|
4 |
import spaces
|
5 |
+
import ctypes
|
6 |
+
import shlex
|
7 |
+
import base64
|
8 |
+
import re
|
9 |
+
import sys
|
10 |
+
|
11 |
+
from models.ISOMER.scripts.utils import fix_vert_color_glb
|
12 |
+
|
13 |
+
sys.path.append(os.path.abspath(os.path.join(__file__, '../')))
|
14 |
+
if 'OMP_NUM_THREADS' not in os.environ:
|
15 |
+
os.environ['OMP_NUM_THREADS'] = '32'
|
16 |
+
|
17 |
+
import shutil
|
18 |
import torch
|
19 |
+
import json
|
20 |
+
import requests
|
21 |
+
import shutil
|
22 |
+
import threading
|
23 |
+
from PIL import Image
|
24 |
+
import time
|
25 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
26 |
+
import trimesh
|
27 |
+
|
28 |
+
import random
|
29 |
+
import time
|
30 |
+
import numpy as np
|
31 |
+
from video_render import render_video_from_obj
|
32 |
+
|
33 |
access_token = os.getenv("HUGGINGFACE_TOKEN")
|
34 |
subprocess.run(
|
35 |
shlex.split(
|
|
|
48 |
"pip install ./extension/renderutils_plugin-0.1.0-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
|
49 |
)
|
50 |
)
|
51 |
+
|
52 |
+
# download cudatoolkit
|
53 |
def install_cuda_toolkit():
|
54 |
# CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
|
55 |
# CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
|
|
|
69 |
os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
|
70 |
print("==> finfish install")
|
71 |
install_cuda_toolkit()
|
72 |
+
|
73 |
@spaces.GPU
|
74 |
def check_gpu():
|
75 |
os.environ['CUDA_HOME'] = '/usr/local/cuda-12.1'
|
|
|
77 |
# os.environ['LD_LIBRARY_PATH'] += ':/usr/local/cuda-12.1/lib64'
|
78 |
os.environ['LD_LIBRARY_PATH'] = "/usr/local/cuda-12.1/lib64:" + os.environ.get('LD_LIBRARY_PATH', '')
|
79 |
subprocess.run(['nvidia-smi']) # 测试 CUDA 是否可用
|
80 |
+
# 显式加载 libnvrtc.so.12
|
81 |
+
cuda_lib_path = "/usr/local/cuda-12.1/lib64/libnvrtc.so.12"
|
82 |
+
try:
|
83 |
+
ctypes.CDLL(cuda_lib_path, mode=ctypes.RTLD_GLOBAL)
|
84 |
+
print(f"Successfully preloaded {cuda_lib_path}")
|
85 |
+
except OSError as e:
|
86 |
+
print(f"Failed to preload {cuda_lib_path}: {e}")
|
87 |
print(f"torch.cuda.is_available:{torch.cuda.is_available()}")
|
88 |
check_gpu()
|
89 |
|
90 |
+
from pipeline.kiss3d_wrapper import init_wrapper_from_config, run_text_to_3d, run_image_to_3d, image2mesh_preprocess, image2mesh_main
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
+
is_running = False
|
93 |
|
94 |
+
TEXT_URL = "http://127.0.0.1:9239/prompt"
|
95 |
+
IMG_URL = ""
|
96 |
+
|
97 |
+
|
98 |
+
KISS_3D_TEXT_FOLDER = "./outputs/text2"
|
99 |
+
KISS_3D_IMG_FOLDER = "./outputs/image2"
|
100 |
+
|
101 |
+
# Add logo file path and hyperlinks
|
102 |
+
LOGO_PATH = "app_assets/logo_temp_.png" # Update this to the actual path of your logo
|
103 |
+
ARXIV_LINK = "https://arxiv.org/abs/example"
|
104 |
+
GITHUB_LINK = "https://github.com/example"
|
105 |
+
|
106 |
+
k3d_wrapper = init_wrapper_from_config('./pipeline/pipeline_config/default.yaml')
|
107 |
+
|
108 |
+
|
109 |
+
TEMP_MESH_ADDRESS=''
|
110 |
+
|
111 |
+
mesh_cache = None
|
112 |
+
preprocessed_input_image = None
|
113 |
+
|
114 |
+
def save_cached_mesh():
|
115 |
+
global mesh_cache
|
116 |
+
return mesh_cache
|
117 |
+
# if mesh_cache is None:
|
118 |
+
# return None
|
119 |
+
# return save_py3dmesh_with_trimesh_fast(mesh_cache)
|
120 |
+
|
121 |
+
def save_py3dmesh_with_trimesh_fast(meshes, save_glb_path=TEMP_MESH_ADDRESS, apply_sRGB_to_LinearRGB=True):
|
122 |
+
from pytorch3d.structures import Meshes
|
123 |
+
import trimesh
|
124 |
+
|
125 |
+
# convert from pytorch3d meshes to trimesh mesh
|
126 |
+
vertices = meshes.verts_packed().cpu().float().numpy()
|
127 |
+
triangles = meshes.faces_packed().cpu().long().numpy()
|
128 |
+
np_color = meshes.textures.verts_features_packed().cpu().float().numpy()
|
129 |
+
if save_glb_path.endswith(".glb"):
|
130 |
+
# rotate 180 along +Y
|
131 |
+
vertices[:, [0, 2]] = -vertices[:, [0, 2]]
|
132 |
+
|
133 |
+
def srgb_to_linear(c_srgb):
|
134 |
+
c_linear = np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4)
|
135 |
+
return c_linear.clip(0, 1.)
|
136 |
+
if apply_sRGB_to_LinearRGB:
|
137 |
+
np_color = srgb_to_linear(np_color)
|
138 |
+
assert vertices.shape[0] == np_color.shape[0]
|
139 |
+
assert np_color.shape[1] == 3
|
140 |
+
assert 0 <= np_color.min() and np_color.max() <= 1, f"min={np_color.min()}, max={np_color.max()}"
|
141 |
+
mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color)
|
142 |
+
mesh.remove_unreferenced_vertices()
|
143 |
+
# save mesh
|
144 |
+
mesh.export(save_glb_path)
|
145 |
+
if save_glb_path.endswith(".glb"):
|
146 |
+
fix_vert_color_glb(save_glb_path)
|
147 |
+
print(f"saving to {save_glb_path}")
|
148 |
+
#
|
149 |
+
#
|
150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
@spaces.GPU
|
152 |
+
def text_to_detailed(prompt, seed=None):
|
153 |
+
print(f"Before text_to_detailed: {torch.cuda.memory_allocated() / 1024**3} GB")
|
154 |
+
return k3d_wrapper.get_detailed_prompt(prompt, seed)
|
155 |
+
|
156 |
+
@spaces.GPU
|
157 |
+
def text_to_image(prompt, seed=None, strength=1.0,lora_scale=1.0, num_inference_steps=30, redux_hparam=None, init_image=None, **kwargs):
|
158 |
+
print(f"Before text_to_image: {torch.cuda.memory_allocated() / 1024**3} GB")
|
159 |
+
k3d_wrapper.renew_uuid()
|
160 |
+
init_image = None
|
161 |
+
if init_image_path is not None:
|
162 |
+
init_image = Image.open(init_image_path)
|
163 |
+
result = k3d_wrapper.generate_3d_bundle_image_text(
|
164 |
+
prompt,
|
165 |
+
image=init_image,
|
166 |
+
strength=strength,
|
167 |
+
lora_scale=lora_scale,
|
168 |
+
num_inference_steps=num_inference_steps,
|
169 |
+
seed=int(seed) if seed is not None else None,
|
170 |
+
redux_hparam=redux_hparam,
|
171 |
+
save_intermediate_results=True,
|
172 |
+
**kwargs)
|
173 |
+
return result[-1]
|
174 |
+
|
175 |
+
def image2mesh_preprocess_(input_image_, seed, use_mv_rgb=True):
|
176 |
+
global preprocessed_input_image
|
177 |
+
|
178 |
+
seed = int(seed) if seed is not None else None
|
179 |
+
|
180 |
+
# TODO: delete this later
|
181 |
+
k3d_wrapper.del_llm_model()
|
|
|
|
|
182 |
|
183 |
+
input_image_save_path, reference_save_path, caption = image2mesh_preprocess(k3d_wrapper, input_image_, seed, use_mv_rgb)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
|
185 |
+
preprocessed_input_image = Image.open(input_image_save_path)
|
186 |
+
return reference_save_path, caption
|
187 |
|
188 |
+
@spaces.GPU
|
189 |
+
def image2mesh_main_(reference_3d_bundle_image, caption, seed, strength1=0.5, strength2=0.95, enable_redux=True, use_controlnet=True, if_video=True):
|
190 |
+
global mesh_cache
|
191 |
+
seed = int(seed) if seed is not None else None
|
192 |
+
|
193 |
+
|
194 |
+
# TODO: delete this later
|
195 |
+
k3d_wrapper.del_llm_model()
|
196 |
+
|
197 |
+
input_image = preprocessed_input_image
|
|
|
198 |
|
199 |
+
reference_3d_bundle_image = torch.tensor(reference_3d_bundle_image).permute(2,0,1)/255
|
200 |
+
|
201 |
+
gen_save_path, recon_mesh_path = image2mesh_main(k3d_wrapper, input_image, reference_3d_bundle_image, caption=caption, seed=seed, strength1=strength1, strength2=strength2, enable_redux=enable_redux, use_controlnet=use_controlnet)
|
202 |
+
mesh_cache = recon_mesh_path
|
203 |
+
|
204 |
+
|
205 |
+
# gen_save_ = Image.open(gen_save_path)
|
206 |
+
|
207 |
+
if if_video:
|
208 |
+
video_path = recon_mesh_path.replace('.obj','.mp4').replace('.glb','.mp4')
|
209 |
+
render_video_from_obj(recon_mesh_path, video_path)
|
210 |
+
print(f"After bundle_image_to_mesh: {torch.cuda.memory_allocated() / 1024**3} GB")
|
211 |
+
return gen_save_path, video_path
|
212 |
+
else:
|
213 |
+
return gen_save_path, recon_mesh_path
|
214 |
+
# return gen_save_path, recon_mesh_path
|
215 |
|
|
|
216 |
@spaces.GPU
|
217 |
+
def bundle_image_to_mesh(
|
218 |
+
gen_3d_bundle_image,
|
219 |
+
lrm_radius = 4.15,
|
220 |
+
isomer_radius = 4.5,
|
221 |
+
reconstruction_stage1_steps = 10,
|
222 |
+
reconstruction_stage2_steps = 50,
|
223 |
+
save_intermediate_results=True,
|
224 |
+
if_video=True
|
225 |
+
):
|
226 |
+
global mesh_cache
|
227 |
+
print(f"Before bundle_image_to_mesh: {torch.cuda.memory_allocated() / 1024**3} GB")
|
228 |
+
|
229 |
+
# TODO: delete this later
|
230 |
+
k3d_wrapper.del_llm_model()
|
231 |
+
|
232 |
+
print(f"Before bundle_image_to_mesh after deleting llm model: {torch.cuda.memory_allocated() / 1024**3} GB")
|
233 |
+
|
234 |
+
gen_3d_bundle_image = torch.tensor(gen_3d_bundle_image).permute(2,0,1)/255
|
235 |
+
# recon from 3D Bundle image
|
236 |
+
recon_mesh_path = k3d_wrapper.reconstruct_3d_bundle_image(gen_3d_bundle_image, lrm_render_radius=lrm_radius, isomer_radius=isomer_radius, save_intermediate_results=save_intermediate_results, reconstruction_stage1_steps=int(reconstruction_stage1_steps), reconstruction_stage2_steps=int(reconstruction_stage2_steps))
|
237 |
+
mesh_cache = recon_mesh_path
|
238 |
+
|
239 |
+
if if_video:
|
240 |
+
video_path = recon_mesh_path.replace('.obj','.mp4').replace('.glb','.mp4')
|
241 |
+
# # 检查这个video_path文件大小是是否超过50KB,不超过的话就认为是空文件,需要重新渲染
|
242 |
+
# if os.path.exists(video_path):
|
243 |
+
# print(f"file size:{os.path.getsize(video_path)}")
|
244 |
+
# if os.path.getsize(video_path) > 50*1024:
|
245 |
+
# print(f"video path:{video_path}")
|
246 |
+
# return video_path
|
247 |
+
render_video_from_obj(recon_mesh_path, video_path)
|
248 |
+
print(f"After bundle_image_to_mesh: {torch.cuda.memory_allocated() / 1024**3} GB")
|
249 |
+
return video_path
|
250 |
+
else:
|
251 |
+
return recon_mesh_path
|
252 |
+
|
253 |
+
_HEADER_=f"""
|
254 |
+
<img src="{LOGO_PATH}">
|
255 |
+
<h2><b>Official 🤗 Gradio Demo</b></h2><h2>
|
256 |
+
<b>Kiss3DGen: Repurposing Image Diffusion Models for 3D Asset Generation</b></a></h2>
|
257 |
+
|
258 |
+
<p>**Kiss3DGen** is xxxxxxxxx</p>
|
259 |
+
|
260 |
+
[]({ARXIV_LINK}) []({GITHUB_LINK})
|
261 |
+
"""
|
262 |
+
|
263 |
+
_CITE_ = r"""
|
264 |
+
<h2>If Kiss3DGen is helpful, please help to ⭐ the <a href='{""" + GITHUB_LINK + r"""}' target='_blank'>Github Repo</a>. Thanks!</h2>
|
265 |
+
|
266 |
+
📝 **Citation**
|
267 |
+
|
268 |
+
If you find our work useful for your research or applications, please cite using this bibtex:
|
269 |
+
```bibtex
|
270 |
+
@article{xxxx,
|
271 |
+
title={xxxx},
|
272 |
+
author={xxxx},
|
273 |
+
journal={xxxx},
|
274 |
+
year={xxxx}
|
275 |
+
}
|
276 |
+
```
|
277 |
+
|
278 |
+
📋 **License**
|
279 |
+
|
280 |
+
Apache-2.0 LICENSE. Please refer to the [LICENSE file](https://huggingface.co/spaces/TencentARC/InstantMesh/blob/main/LICENSE) for details.
|
281 |
+
|
282 |
+
📧 **Contact**
|
283 |
+
|
284 |
+
If you have any questions, feel free to open a discussion or contact us at <b>xxx@xxxx</b>.
|
285 |
+
"""
|
286 |
+
|
287 |
+
def image_to_base64(image_path):
|
288 |
+
"""Converts an image file to a base64-encoded string."""
|
289 |
+
with open(image_path, "rb") as img_file:
|
290 |
+
return base64.b64encode(img_file.read()).decode('utf-8')
|
291 |
+
|
292 |
+
def main():
|
293 |
+
|
294 |
+
torch.set_grad_enabled(False)
|
295 |
+
|
296 |
+
# Convert the logo image to base64
|
297 |
+
logo_base64 = image_to_base64(LOGO_PATH)
|
298 |
+
# with gr.Blocks() as demo:
|
299 |
+
with gr.Blocks(css="""
|
300 |
+
body {
|
301 |
+
display: flex;
|
302 |
+
justify-content: center;
|
303 |
+
align-items: center;
|
304 |
+
min-height: 100vh;
|
305 |
+
margin: 0;
|
306 |
+
padding: 0;
|
307 |
+
}
|
308 |
+
#col-container { margin: 0px auto; max-width: 200px; }
|
309 |
+
|
310 |
+
|
311 |
+
.gradio-container {
|
312 |
+
max-width: 1000px;
|
313 |
+
margin: auto;
|
314 |
+
width: 100%;
|
315 |
+
}
|
316 |
+
#center-align-column {
|
317 |
+
display: flex;
|
318 |
+
justify-content: center;
|
319 |
+
align-items: center;
|
320 |
+
}
|
321 |
+
#right-align-column {
|
322 |
+
display: flex;
|
323 |
+
justify-content: flex-end;
|
324 |
+
align-items: center;
|
325 |
+
}
|
326 |
+
h1 {text-align: center;}
|
327 |
+
h2 {text-align: center;}
|
328 |
+
h3 {text-align: center;}
|
329 |
+
p {text-align: center;}
|
330 |
+
img {text-align: right;}
|
331 |
+
.right {
|
332 |
+
display: block;
|
333 |
+
margin-left: auto;
|
334 |
+
}
|
335 |
+
.center {
|
336 |
+
display: block;
|
337 |
+
margin-left: auto;
|
338 |
+
margin-right: auto;
|
339 |
+
width: 50%;
|
340 |
+
|
341 |
+
#content-container {
|
342 |
+
max-width: 1200px;
|
343 |
+
margin: 0 auto;
|
344 |
+
}
|
345 |
+
#example-container {
|
346 |
+
max-width: 300px;
|
347 |
+
margin: 0 auto;
|
348 |
+
}
|
349 |
+
""",elem_id="col-container") as demo:
|
350 |
+
# Header Section
|
351 |
+
# gr.Image(value=LOGO_PATH, width=64, height=64)
|
352 |
+
# gr.Markdown(_HEADER_)
|
353 |
+
with gr.Row(elem_id="content-container"):
|
354 |
+
# with gr.Column(scale=1):
|
355 |
+
# pass
|
356 |
+
# with gr.Column(scale=1, elem_id="right-align-column"):
|
357 |
+
# # gr.Image(value=LOGO_PATH, interactive=False, show_label=False, width=64, height=64, elem_id="logo-image")
|
358 |
+
# # gr.Markdown(f"<img src='{LOGO_PATH}' alt='Logo' style='width:64px;height:64px;border:0;'>")
|
359 |
+
# # gr.HTML(f"<img src='data:image/png;base64,{logo_base64}' alt='Logo' class='right' style='width:64px;height:64px;border:0;text-align:right;'>")
|
360 |
+
# pass
|
361 |
+
with gr.Column(scale=7, elem_id="center-align-column"):
|
362 |
+
gr.Markdown(f"""
|
363 |
+
## Official 🤗 Gradio Demo
|
364 |
+
# Kiss3DGen: Repurposing Image Diffusion Models for 3D Asset Generation""")
|
365 |
+
gr.HTML(f"<img src='data:image/png;base64,{logo_base64}' alt='Logo' class='center' style='width:64px;height:64px;border:0;text-align:center;'>")
|
366 |
+
|
367 |
+
gr.HTML(f"""
|
368 |
+
<div style="display: flex; justify-content: center; align-items: center; gap: 10px;">
|
369 |
+
<a href="{ARXIV_LINK}" target="_blank">
|
370 |
+
<img src="https://img.shields.io/badge/arXiv-Link-red" alt="arXiv">
|
371 |
+
</a>
|
372 |
+
<a href="{GITHUB_LINK}" target="_blank">
|
373 |
+
<img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub">
|
374 |
+
</a>
|
375 |
+
</div>
|
376 |
+
|
377 |
+
""")
|
378 |
+
|
379 |
+
|
380 |
+
# gr.HTML(f"""
|
381 |
+
# <div style="display: flex; gap: 10px; align-items: center;"><a href="{ARXIV_LINK}" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/arXiv-Link-red" alt="arXiv"></a> <a href="{GITHUB_LINK}" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub"></a></div>
|
382 |
+
# """)
|
383 |
+
|
384 |
+
# gr.Markdown(f"""
|
385 |
+
# []({ARXIV_LINK}) []({GITHUB_LINK})
|
386 |
+
# """, elem_id="title")
|
387 |
+
# with gr.Column(scale=1):
|
388 |
+
# pass
|
389 |
+
# with gr.Row():
|
390 |
+
# gr.Markdown(f"[]({ARXIV_LINK})")
|
391 |
+
# gr.Markdown(f"[]({GITHUB_LINK})")
|
392 |
+
|
393 |
+
# Tabs Section
|
394 |
+
with gr.Tabs(selected='tab_text_to_3d', elem_id="content-container") as main_tabs:
|
395 |
+
with gr.TabItem('Text-to-3D', id='tab_text_to_3d'):
|
396 |
+
with gr.Row():
|
397 |
+
with gr.Column(scale=1):
|
398 |
+
prompt = gr.Textbox(value="", label="Input Prompt", lines=4)
|
399 |
+
seed1 = gr.Number(value=10, label="Seed")
|
400 |
+
|
401 |
+
with gr.Row(elem_id="example-container"):
|
402 |
+
gr.Examples(
|
403 |
+
examples=[
|
404 |
+
# ["A tree with red leaves"],
|
405 |
+
# ["A dragon with black texture"],
|
406 |
+
["A girl with pink hair"],
|
407 |
+
["A boy playing guitar"],
|
408 |
+
|
409 |
+
|
410 |
+
["A dog wearing a hat"],
|
411 |
+
["A boy playing basketball"],
|
412 |
+
# [""],
|
413 |
+
# [""],
|
414 |
+
# [""],
|
415 |
+
|
416 |
+
],
|
417 |
+
inputs=[prompt], # 将选中的示例填入 prompt 文本框
|
418 |
+
label="Example Prompts"
|
419 |
+
)
|
420 |
+
btn_text2detailed = gr.Button("Refine to detailed prompt")
|
421 |
+
detailed_prompt = gr.Textbox(value="", label="Detailed Prompt", placeholder="detailed prompt will be generated here base on your input prompt. You can also edit this prompt", lines=4, interactive=True)
|
422 |
+
btn_text2img = gr.Button("Generate Images")
|
423 |
+
|
424 |
+
with gr.Column(scale=1):
|
425 |
+
output_image1 = gr.Image(label="Generated image", interactive=False)
|
426 |
+
|
427 |
+
|
428 |
+
# lrm_radius = gr.Number(value=4.15, label="lrm_radius")
|
429 |
+
# isomer_radius = gr.Number(value=4.5, label="isomer_radius")
|
430 |
+
# reconstruction_stage1_steps = gr.Number(value=10, label="reconstruction_stage1_steps")
|
431 |
+
# reconstruction_stage2_steps = gr.Number(value=50, label="reconstruction_stage2_steps")
|
432 |
+
|
433 |
+
btn_gen_mesh = gr.Button("Generate Mesh")
|
434 |
+
output_video1 = gr.Video(label="Generated Video", interactive=False, loop=True, autoplay=True)
|
435 |
+
btn_download1 = gr.Button("Download Mesh")
|
436 |
+
|
437 |
+
file_output1 = gr.File()
|
438 |
+
|
439 |
+
with gr.TabItem('Image-to-3D', id='tab_image_to_3d'):
|
440 |
+
with gr.Row():
|
441 |
+
with gr.Column(scale=1):
|
442 |
+
image = gr.Image(label="Input Image", type="pil")
|
443 |
+
|
444 |
+
seed2 = gr.Number(value=10, label="Seed (0 for random)")
|
445 |
+
|
446 |
+
btn_img2mesh_preprocess = gr.Button("Preprocess Image")
|
447 |
+
|
448 |
+
image_caption = gr.Textbox(value="", label="Image Caption", placeholder="caption will be generated here base on your input image. You can also edit this caption", lines=4, interactive=True)
|
449 |
+
|
450 |
+
output_image2 = gr.Image(label="Generated image", interactive=False)
|
451 |
+
strength1 = gr.Slider(minimum=0, maximum=1.0, step=0.01, value=0.5, label="strength1")
|
452 |
+
strength2 = gr.Slider(minimum=0, maximum=1.0, step=0.01, value=0.95, label="strength2")
|
453 |
+
enable_redux = gr.Checkbox(label="enable redux", value=True)
|
454 |
+
use_controlnet = gr.Checkbox(label="use controlnet", value=True)
|
455 |
+
|
456 |
+
btn_img2mesh_main = gr.Button("Generate Mesh")
|
457 |
+
|
458 |
+
with gr.Column(scale=1):
|
459 |
+
|
460 |
+
# output_mesh2 = gr.Model3D(label="Generated Mesh", interactive=False)
|
461 |
+
output_image3 = gr.Image(label="gen save image", interactive=False)
|
462 |
+
output_video2 = gr.Video(label="Generated Video", interactive=False, loop=True, autoplay=True)
|
463 |
+
btn_download2 = gr.Button("Download Mesh")
|
464 |
+
file_output2 = gr.File()
|
465 |
+
|
466 |
+
# Image2
|
467 |
+
btn_img2mesh_preprocess.click(fn=image2mesh_preprocess_, inputs=[image, seed2], outputs=[output_image2, image_caption])
|
468 |
+
|
469 |
+
btn_img2mesh_main.click(fn=image2mesh_main_, inputs=[output_image2, image_caption, seed2, strength1, strength2, enable_redux, use_controlnet], outputs=[output_image3, output_video2])
|
470 |
+
|
471 |
+
|
472 |
+
btn_download2.click(fn=save_cached_mesh, inputs=[], outputs=file_output2)
|
473 |
+
|
474 |
+
|
475 |
+
# Button Click Events
|
476 |
+
# Text2
|
477 |
+
btn_text2detailed.click(fn=text_to_detailed, inputs=[prompt, seed1], outputs=detailed_prompt)
|
478 |
+
btn_text2img.click(fn=text_to_image, inputs=[detailed_prompt, seed1], outputs=output_image1)
|
479 |
+
btn_gen_mesh.click(fn=bundle_image_to_mesh, inputs=[output_image1,], outputs=output_video1)
|
480 |
+
# btn_gen_mesh.click(fn=bundle_image_to_mesh, inputs=[output_image1, lrm_radius, isomer_radius, reconstruction_stage1_steps, reconstruction_stage2_steps], outputs=output_video1)
|
481 |
+
|
482 |
+
with gr.Row():
|
483 |
+
pass
|
484 |
+
with gr.Row():
|
485 |
+
gr.Markdown(_CITE_)
|
486 |
+
|
487 |
+
# demo.queue(default_concurrency_limit=1)
|
488 |
+
# demo.launch(server_name="0.0.0.0", server_port=9239)
|
489 |
+
demo.launch()
|
490 |
+
|
491 |
|
492 |
+
if __name__ == "__main__":
|
493 |
+
main()
|
|
demo.py → app_demo.py
RENAMED
@@ -4,11 +4,10 @@ import subprocess
|
|
4 |
import shlex
|
5 |
import spaces
|
6 |
import torch
|
7 |
-
import numpy as numpy
|
8 |
access_token = os.getenv("HUGGINGFACE_TOKEN")
|
9 |
subprocess.run(
|
10 |
shlex.split(
|
11 |
-
"pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/
|
12 |
)
|
13 |
)
|
14 |
|
@@ -20,7 +19,7 @@ subprocess.run(
|
|
20 |
|
21 |
subprocess.run(
|
22 |
shlex.split(
|
23 |
-
"pip install ./extension/renderutils_plugin-1.0-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
|
24 |
)
|
25 |
)
|
26 |
def install_cuda_toolkit():
|
@@ -41,7 +40,7 @@ def install_cuda_toolkit():
|
|
41 |
# Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
|
42 |
os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
|
43 |
print("==> finfish install")
|
44 |
-
|
45 |
@spaces.GPU
|
46 |
def check_gpu():
|
47 |
os.environ['CUDA_HOME'] = '/usr/local/cuda-12.1'
|
@@ -84,8 +83,8 @@ from huggingface_hub import hf_hub_download
|
|
84 |
|
85 |
from utils.tool import NormalTransfer, get_background, get_render_cameras_video, load_mipmap, render_frames
|
86 |
|
87 |
-
device_0 = "cuda
|
88 |
-
device_1 = "cuda
|
89 |
resolution = 512
|
90 |
save_dir = "./outputs"
|
91 |
normal_transfer = NormalTransfer()
|
@@ -97,15 +96,15 @@ isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(d
|
|
97 |
|
98 |
# model initialization and loading
|
99 |
# flux
|
100 |
-
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to(device_0)
|
101 |
-
good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16, token=access_token).to(device_0)
|
102 |
# flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, token=access_token).to(device=device_0, dtype=torch.bfloat16)
|
103 |
-
flux_pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, vae=taef1, token=access_token).to(device_0)
|
104 |
-
flux_lora_ckpt_path = hf_hub_download(repo_id="LTT/xxx-ckpt", filename="rgb_normal_large.safetensors", repo_type="model")
|
105 |
-
flux_pipe.load_lora_weights(flux_lora_ckpt_path)
|
106 |
# flux_pipe.to(device=device_0, dtype=torch.bfloat16)
|
107 |
-
torch.cuda.empty_cache()
|
108 |
-
flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(flux_pipe)
|
109 |
|
110 |
|
111 |
# lrm
|
@@ -159,7 +158,7 @@ def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", expor
|
|
159 |
|
160 |
all_mv, all_mvp, all_campos = get_render_cameras_video(
|
161 |
batch_size=1,
|
162 |
-
M=
|
163 |
radius=4.5,
|
164 |
elevation=(90, 60.0),
|
165 |
is_flexicubes=True,
|
@@ -209,28 +208,27 @@ def generate_multi_view_images(prompt, seed):
|
|
209 |
# generator = torch.manual_seed(seed)
|
210 |
generator = torch.Generator().manual_seed(seed)
|
211 |
with torch.no_grad():
|
212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
# prompt=prompt,
|
214 |
-
# num_inference_steps=10,
|
215 |
# guidance_scale=3.5,
|
216 |
-
#
|
217 |
# width=resolution * 4,
|
218 |
# height=resolution * 2,
|
219 |
-
# output_type='np',
|
220 |
# generator=generator,
|
|
|
221 |
# good_vae=good_vae,
|
222 |
-
# )
|
223 |
-
|
224 |
-
prompt=prompt,
|
225 |
-
guidance_scale=3.5,
|
226 |
-
num_inference_steps=10,
|
227 |
-
width=resolution * 4,
|
228 |
-
height=resolution * 2,
|
229 |
-
generator=generator,
|
230 |
-
output_type="np",
|
231 |
-
good_vae=good_vae,
|
232 |
-
):
|
233 |
-
pass
|
234 |
# 返回最终的图像和种子(通过外部调用处理)
|
235 |
return img
|
236 |
|
@@ -251,7 +249,7 @@ def reconstruct_3d_model(images, prompt):
|
|
251 |
multi_view_mask = get_background(normal_multi_view)
|
252 |
rgb_multi_view = rgb_multi_view * rgb_multi_view + (1-multi_view_mask)
|
253 |
input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device_1)
|
254 |
-
vertices, faces = lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', export_texmap=False, if_save_video=
|
255 |
# local normal to global normal
|
256 |
|
257 |
global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1), isomer_azimuths, isomer_elevations)
|
@@ -307,19 +305,81 @@ def reconstruct_3d_model(images, prompt):
|
|
307 |
# Gradio 接口函数
|
308 |
@spaces.GPU
|
309 |
def gradio_pipeline(prompt, seed):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
310 |
# 生成多视图图像
|
311 |
-
rgb_normal_grid = generate_multi_view_images(prompt, seed)
|
312 |
-
|
|
|
313 |
|
314 |
# 3d reconstruction
|
315 |
|
316 |
|
317 |
# 重建 3D 模型并返回 glb 路径
|
318 |
save_glb_addr = reconstruct_3d_model(rgb_normal_grid, prompt)
|
319 |
-
|
320 |
return image_preview, save_glb_addr
|
321 |
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
import shlex
|
5 |
import spaces
|
6 |
import torch
|
|
|
7 |
access_token = os.getenv("HUGGINGFACE_TOKEN")
|
8 |
subprocess.run(
|
9 |
shlex.split(
|
10 |
+
"pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt240/download.html"
|
11 |
)
|
12 |
)
|
13 |
|
|
|
19 |
|
20 |
subprocess.run(
|
21 |
shlex.split(
|
22 |
+
"pip install ./extension/renderutils_plugin-0.1.0-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
|
23 |
)
|
24 |
)
|
25 |
def install_cuda_toolkit():
|
|
|
40 |
# Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
|
41 |
os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
|
42 |
print("==> finfish install")
|
43 |
+
install_cuda_toolkit()
|
44 |
@spaces.GPU
|
45 |
def check_gpu():
|
46 |
os.environ['CUDA_HOME'] = '/usr/local/cuda-12.1'
|
|
|
83 |
|
84 |
from utils.tool import NormalTransfer, get_background, get_render_cameras_video, load_mipmap, render_frames
|
85 |
|
86 |
+
device_0 = "cuda"
|
87 |
+
device_1 = "cuda"
|
88 |
resolution = 512
|
89 |
save_dir = "./outputs"
|
90 |
normal_transfer = NormalTransfer()
|
|
|
96 |
|
97 |
# model initialization and loading
|
98 |
# flux
|
99 |
+
# # taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to(device_0)
|
100 |
+
# # good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16, token=access_token).to(device_0)
|
101 |
# flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, token=access_token).to(device=device_0, dtype=torch.bfloat16)
|
102 |
+
# # flux_pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, vae=taef1, token=access_token).to(device_0)
|
103 |
+
# flux_lora_ckpt_path = hf_hub_download(repo_id="LTT/xxx-ckpt", filename="rgb_normal_large.safetensors", repo_type="model", token=access_token)
|
104 |
+
# flux_pipe.load_lora_weights(flux_lora_ckpt_path)
|
105 |
# flux_pipe.to(device=device_0, dtype=torch.bfloat16)
|
106 |
+
# torch.cuda.empty_cache()
|
107 |
+
# flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(flux_pipe)
|
108 |
|
109 |
|
110 |
# lrm
|
|
|
158 |
|
159 |
all_mv, all_mvp, all_campos = get_render_cameras_video(
|
160 |
batch_size=1,
|
161 |
+
M=24,
|
162 |
radius=4.5,
|
163 |
elevation=(90, 60.0),
|
164 |
is_flexicubes=True,
|
|
|
208 |
# generator = torch.manual_seed(seed)
|
209 |
generator = torch.Generator().manual_seed(seed)
|
210 |
with torch.no_grad():
|
211 |
+
img = flux_pipe(
|
212 |
+
prompt=prompt,
|
213 |
+
num_inference_steps=5,
|
214 |
+
guidance_scale=3.5,
|
215 |
+
num_images_per_prompt=1,
|
216 |
+
width=resolution * 2,
|
217 |
+
height=resolution * 1,
|
218 |
+
output_type='np',
|
219 |
+
generator=generator,
|
220 |
+
).images
|
221 |
+
# for img in flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images(
|
222 |
# prompt=prompt,
|
|
|
223 |
# guidance_scale=3.5,
|
224 |
+
# num_inference_steps=4,
|
225 |
# width=resolution * 4,
|
226 |
# height=resolution * 2,
|
|
|
227 |
# generator=generator,
|
228 |
+
# output_type="np",
|
229 |
# good_vae=good_vae,
|
230 |
+
# ):
|
231 |
+
# pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
# 返回最终的图像和种子(通过外部调用处理)
|
233 |
return img
|
234 |
|
|
|
249 |
multi_view_mask = get_background(normal_multi_view)
|
250 |
rgb_multi_view = rgb_multi_view * rgb_multi_view + (1-multi_view_mask)
|
251 |
input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device_1)
|
252 |
+
vertices, faces = lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', export_texmap=False, if_save_video=True)
|
253 |
# local normal to global normal
|
254 |
|
255 |
global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1), isomer_azimuths, isomer_elevations)
|
|
|
305 |
# Gradio 接口函数
|
306 |
@spaces.GPU
|
307 |
def gradio_pipeline(prompt, seed):
|
308 |
+
import ctypes
|
309 |
+
# 显式加载 libnvrtc.so.12
|
310 |
+
cuda_lib_path = "/usr/local/cuda-12.1/lib64/libnvrtc.so.12"
|
311 |
+
try:
|
312 |
+
ctypes.CDLL(cuda_lib_path, mode=ctypes.RTLD_GLOBAL)
|
313 |
+
print(f"Successfully preloaded {cuda_lib_path}")
|
314 |
+
except OSError as e:
|
315 |
+
print(f"Failed to preload {cuda_lib_path}: {e}")
|
316 |
# 生成多视图图像
|
317 |
+
# rgb_normal_grid = generate_multi_view_images(prompt, seed)
|
318 |
+
rgb_normal_grid = np.load("rgb_normal_grid.npy")
|
319 |
+
image_preview = Image.fromarray((rgb_normal_grid[0] * 255).astype(np.uint8))
|
320 |
|
321 |
# 3d reconstruction
|
322 |
|
323 |
|
324 |
# 重建 3D 模型并返回 glb 路径
|
325 |
save_glb_addr = reconstruct_3d_model(rgb_normal_grid, prompt)
|
326 |
+
# save_glb_addr = None
|
327 |
return image_preview, save_glb_addr
|
328 |
|
329 |
+
# Gradio Blocks 应用
|
330 |
+
with gr.Blocks() as demo:
|
331 |
+
with gr.Row(variant="panel"):
|
332 |
+
# 左侧输入区域
|
333 |
+
with gr.Column():
|
334 |
+
with gr.Row():
|
335 |
+
prompt_input = gr.Textbox(
|
336 |
+
label="Enter Prompt",
|
337 |
+
placeholder="Describe your 3D model...",
|
338 |
+
lines=2,
|
339 |
+
elem_id="prompt_input"
|
340 |
+
)
|
341 |
+
|
342 |
+
with gr.Row():
|
343 |
+
sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
|
344 |
+
|
345 |
+
with gr.Row():
|
346 |
+
submit = gr.Button("Generate", elem_id="generate", variant="primary")
|
347 |
+
|
348 |
+
with gr.Row(variant="panel"):
|
349 |
+
gr.Markdown("Examples:")
|
350 |
+
gr.Examples(
|
351 |
+
examples=[
|
352 |
+
["a castle on a hill"],
|
353 |
+
["an owl wearing a hat"],
|
354 |
+
["a futuristic car"]
|
355 |
+
],
|
356 |
+
inputs=[prompt_input],
|
357 |
+
label="Prompt Examples"
|
358 |
+
)
|
359 |
+
|
360 |
+
# 右侧输出区域
|
361 |
+
with gr.Column():
|
362 |
+
with gr.Row():
|
363 |
+
rgb_normal_grid_image = gr.Image(
|
364 |
+
label="RGB Normal Grid",
|
365 |
+
type="pil",
|
366 |
+
interactive=False
|
367 |
+
)
|
368 |
+
|
369 |
+
with gr.Row():
|
370 |
+
with gr.Tab("GLB"):
|
371 |
+
output_glb_model = gr.Model3D(
|
372 |
+
label="Generated 3D Model (GLB Format)",
|
373 |
+
interactive=False
|
374 |
+
)
|
375 |
+
gr.Markdown("Download the model for proper visualization.")
|
376 |
+
|
377 |
+
# 处理逻辑
|
378 |
+
submit.click(
|
379 |
+
fn=gradio_pipeline, inputs=[prompt_input, sample_seed],
|
380 |
+
outputs=[rgb_normal_grid_image, output_glb_model]
|
381 |
+
)
|
382 |
+
|
383 |
+
# 启动应用
|
384 |
+
# demo.queue(max_size=10)
|
385 |
+
demo.launch()
|
app_flux.py
DELETED
@@ -1,141 +0,0 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
import numpy as np
|
3 |
-
import os
|
4 |
-
import random
|
5 |
-
import spaces
|
6 |
-
import torch
|
7 |
-
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
|
8 |
-
from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
|
9 |
-
from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
|
10 |
-
|
11 |
-
dtype = torch.bfloat16
|
12 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
13 |
-
access_token = os.getenv("HUGGINGFACE_TOKEN")
|
14 |
-
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
|
15 |
-
good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype, token=access_token).to(device)
|
16 |
-
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=taef1, token=access_token).to(device)
|
17 |
-
torch.cuda.empty_cache()
|
18 |
-
|
19 |
-
MAX_SEED = np.iinfo(np.int32).max
|
20 |
-
MAX_IMAGE_SIZE = 2048
|
21 |
-
|
22 |
-
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
|
23 |
-
|
24 |
-
@spaces.GPU(duration=75)
|
25 |
-
def infer(prompt, seed=42, randomize_seed=False, width=2048, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
|
26 |
-
if randomize_seed:
|
27 |
-
seed = random.randint(0, MAX_SEED)
|
28 |
-
generator = torch.Generator().manual_seed(seed)
|
29 |
-
|
30 |
-
for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
|
31 |
-
prompt=prompt,
|
32 |
-
guidance_scale=guidance_scale,
|
33 |
-
num_inference_steps=num_inference_steps,
|
34 |
-
width=width,
|
35 |
-
height=height,
|
36 |
-
generator=generator,
|
37 |
-
output_type="pil",
|
38 |
-
good_vae=good_vae,
|
39 |
-
):
|
40 |
-
# yield img, seed
|
41 |
-
pass
|
42 |
-
return img, seed
|
43 |
-
examples = [
|
44 |
-
"a tiny astronaut hatching from an egg on the moon",
|
45 |
-
"a cat holding a sign that says hello world",
|
46 |
-
"an anime illustration of a wiener schnitzel",
|
47 |
-
]
|
48 |
-
|
49 |
-
css="""
|
50 |
-
#col-container {
|
51 |
-
margin: 0 auto;
|
52 |
-
max-width: 520px;
|
53 |
-
}
|
54 |
-
"""
|
55 |
-
|
56 |
-
with gr.Blocks(css=css) as demo:
|
57 |
-
|
58 |
-
with gr.Column(elem_id="col-container"):
|
59 |
-
gr.Markdown(f"""# FLUX.1 [dev]
|
60 |
-
12B param rectified flow transformer guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/)
|
61 |
-
[[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-dev)]
|
62 |
-
""")
|
63 |
-
|
64 |
-
with gr.Row():
|
65 |
-
|
66 |
-
prompt = gr.Text(
|
67 |
-
label="Prompt",
|
68 |
-
show_label=False,
|
69 |
-
max_lines=1,
|
70 |
-
placeholder="Enter your prompt",
|
71 |
-
container=False,
|
72 |
-
)
|
73 |
-
|
74 |
-
run_button = gr.Button("Run", scale=0)
|
75 |
-
|
76 |
-
result = gr.Image(label="Result", show_label=False)
|
77 |
-
|
78 |
-
with gr.Accordion("Advanced Settings", open=False):
|
79 |
-
|
80 |
-
seed = gr.Slider(
|
81 |
-
label="Seed",
|
82 |
-
minimum=0,
|
83 |
-
maximum=MAX_SEED,
|
84 |
-
step=1,
|
85 |
-
value=0,
|
86 |
-
)
|
87 |
-
|
88 |
-
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
89 |
-
|
90 |
-
with gr.Row():
|
91 |
-
|
92 |
-
width = gr.Slider(
|
93 |
-
label="Width",
|
94 |
-
minimum=256,
|
95 |
-
maximum=MAX_IMAGE_SIZE,
|
96 |
-
step=32,
|
97 |
-
value=1024,
|
98 |
-
)
|
99 |
-
|
100 |
-
height = gr.Slider(
|
101 |
-
label="Height",
|
102 |
-
minimum=256,
|
103 |
-
maximum=MAX_IMAGE_SIZE,
|
104 |
-
step=32,
|
105 |
-
value=1024,
|
106 |
-
)
|
107 |
-
|
108 |
-
with gr.Row():
|
109 |
-
|
110 |
-
guidance_scale = gr.Slider(
|
111 |
-
label="Guidance Scale",
|
112 |
-
minimum=1,
|
113 |
-
maximum=15,
|
114 |
-
step=0.1,
|
115 |
-
value=3.5,
|
116 |
-
)
|
117 |
-
|
118 |
-
num_inference_steps = gr.Slider(
|
119 |
-
label="Number of inference steps",
|
120 |
-
minimum=1,
|
121 |
-
maximum=50,
|
122 |
-
step=1,
|
123 |
-
value=28,
|
124 |
-
)
|
125 |
-
|
126 |
-
gr.Examples(
|
127 |
-
examples = examples,
|
128 |
-
fn = infer,
|
129 |
-
inputs = [prompt],
|
130 |
-
outputs = [result, seed],
|
131 |
-
cache_examples="lazy"
|
132 |
-
)
|
133 |
-
|
134 |
-
gr.on(
|
135 |
-
triggers=[run_button.click, prompt.submit],
|
136 |
-
fn = infer,
|
137 |
-
inputs = [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
|
138 |
-
outputs = [result, seed]
|
139 |
-
)
|
140 |
-
|
141 |
-
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
image_to_mesh.py
DELETED
@@ -1,437 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from einops import rearrange
|
3 |
-
from omegaconf import OmegaConf
|
4 |
-
import torch
|
5 |
-
import numpy as np
|
6 |
-
import trimesh
|
7 |
-
import torchvision
|
8 |
-
import torch.nn.functional as F
|
9 |
-
from PIL import Image
|
10 |
-
from torchvision import transforms
|
11 |
-
from torchvision.transforms import v2
|
12 |
-
from transformers import AutoProcessor, AutoModelForCausalLM
|
13 |
-
import rembg
|
14 |
-
from diffusers import FluxPipeline, FluxControlNetImg2ImgPipeline
|
15 |
-
from diffusers.models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
|
16 |
-
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, HeunDiscreteScheduler
|
17 |
-
from pytorch_lightning import seed_everything
|
18 |
-
import os
|
19 |
-
|
20 |
-
from models.ISOMER.reconstruction_func import reconstruction
|
21 |
-
from models.ISOMER.projection_func import projection
|
22 |
-
from models.lrm.utils.infer_util import remove_background, resize_foreground, save_video
|
23 |
-
from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
|
24 |
-
from models.lrm.utils.render_utils import rotate_x, rotate_y
|
25 |
-
from models.lrm.utils.train_util import instantiate_from_config
|
26 |
-
from models.lrm.utils.camera_util import get_zero123plus_input_cameras, get_custom_zero123plus_input_cameras, get_flux_input_cameras
|
27 |
-
from utils.tool import NormalTransfer, get_render_cameras_frames, load_mipmap
|
28 |
-
from utils.tool import get_background, get_render_cameras_video, render_frames
|
29 |
-
import time
|
30 |
-
|
31 |
-
device = "cuda"
|
32 |
-
resolution = 512
|
33 |
-
save_dir = "./outputs"
|
34 |
-
zero123plus_diffusion_steps = 75
|
35 |
-
normal_transfer = NormalTransfer()
|
36 |
-
rembg_session = rembg.new_session()
|
37 |
-
isomer_azimuths = torch.from_numpy(np.array([270, 0, 90, 180])).to(device)
|
38 |
-
isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).to(device)
|
39 |
-
isomer_radius = 4.1
|
40 |
-
isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device)
|
41 |
-
isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device)
|
42 |
-
# seed_everything(42)
|
43 |
-
|
44 |
-
# model initialization and loading
|
45 |
-
# flux
|
46 |
-
print('==> Loading Flux model ...')
|
47 |
-
flux_base_model_pth = "/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/models--black-forest-labs--FLUX.1-dev"
|
48 |
-
flux_controlnet = FluxControlNetModel.from_pretrained("/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/flux_controlnets/FLUX.1-dev-ControlNet-Union-Pro")
|
49 |
-
flux_pipe = FluxControlNetImg2ImgPipeline.from_pretrained(flux_base_model_pth, controlnet=[flux_controlnet], torch_dtype=torch.bfloat16).to(device=device, dtype=torch.bfloat16)
|
50 |
-
|
51 |
-
flux_pipe.load_lora_weights('./checkpoint/flux_lora/rgb_normal_large.safetensors')
|
52 |
-
|
53 |
-
|
54 |
-
flux_pipe.to(device=device, dtype=torch.bfloat16)
|
55 |
-
generator = torch.Generator(device=device).manual_seed(0)
|
56 |
-
|
57 |
-
# lrm
|
58 |
-
print('==> Loading LRM model ...')
|
59 |
-
config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
|
60 |
-
model_config = config.model_config
|
61 |
-
infer_config = config.infer_config
|
62 |
-
model = instantiate_from_config(model_config)
|
63 |
-
model_ckpt_path = "./checkpoint/lrm/final_ckpt.ckpt"
|
64 |
-
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
65 |
-
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
|
66 |
-
model.load_state_dict(state_dict, strict=True)
|
67 |
-
|
68 |
-
model = model.to(device)
|
69 |
-
model.init_flexicubes_geometry(device, fovy=50.0)
|
70 |
-
model = model.eval()
|
71 |
-
|
72 |
-
# zero123++
|
73 |
-
print('==> Loading diffusion model ...')
|
74 |
-
zero123plus_pipeline = DiffusionPipeline.from_pretrained(
|
75 |
-
"sudo-ai/zero123plus-v1.2",
|
76 |
-
custom_pipeline="./models/zero123plus",
|
77 |
-
torch_dtype=torch.float16,
|
78 |
-
)
|
79 |
-
zero123plus_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
80 |
-
zero123plus_pipeline.scheduler.config, timestep_spacing='trailing'
|
81 |
-
)
|
82 |
-
unet_ckpt_path = "./checkpoint/zero123++/flexgen_19w.ckpt"
|
83 |
-
state_dict = torch.load(unet_ckpt_path, map_location='cpu')['state_dict']
|
84 |
-
state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')}
|
85 |
-
zero123plus_pipeline.unet.load_state_dict(state_dict, strict=True)
|
86 |
-
zero123plus_pipeline = zero123plus_pipeline.to(device)
|
87 |
-
|
88 |
-
# unet_ckpt_path = "checkpoint/zero123++/diffusion_pytorch_model.bin"
|
89 |
-
# state_dict = torch.load(unet_ckpt_path, map_location='cpu')
|
90 |
-
# zero123plus_pipeline.unet.load_state_dict(state_dict, strict=True)
|
91 |
-
# zero123plus_pipeline = zero123plus_pipeline.to(device)
|
92 |
-
|
93 |
-
# florence
|
94 |
-
caption_model = AutoModelForCausalLM.from_pretrained(
|
95 |
-
"/hpc2hdd/home/jlin695/.cache/huggingface/hub/models--multimodalart--Florence-2-large-no-flash-attn/snapshots/8db3793cf5b453b2ccfb3a4f613b403b2e6b7ca2", torch_dtype=torch.bfloat16, trust_remote_code=True,
|
96 |
-
).to(device)
|
97 |
-
caption_processor = AutoProcessor.from_pretrained("/hpc2hdd/home/jlin695/.cache/huggingface/hub/models--multimodalart--Florence-2-large-no-flash-attn/snapshots/8db3793cf5b453b2ccfb3a4f613b403b2e6b7ca2", trust_remote_code=True)
|
98 |
-
|
99 |
-
# Flux multi-view generation
|
100 |
-
def multi_view_rgb_normal_generation_with_controlnet(prompt, image, strength=1.0,
|
101 |
-
control_image=[],
|
102 |
-
control_mode=[],
|
103 |
-
control_guidance_start=None,
|
104 |
-
control_guidance_end=None,
|
105 |
-
controlnet_conditioning_scale=None,
|
106 |
-
lora_scale=1.0
|
107 |
-
):
|
108 |
-
control_mode_dict = {
|
109 |
-
'canny': 0,
|
110 |
-
'tile': 1,
|
111 |
-
'depth': 2,
|
112 |
-
'blur': 3,
|
113 |
-
'pose': 4,
|
114 |
-
'gray': 5,
|
115 |
-
'lq': 6,
|
116 |
-
} # for https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union only
|
117 |
-
|
118 |
-
hparam_dict = {
|
119 |
-
'prompt': prompt,
|
120 |
-
'image': image,
|
121 |
-
'strength': strength,
|
122 |
-
'num_inference_steps': 30,
|
123 |
-
'guidance_scale': 3.5,
|
124 |
-
'num_images_per_prompt': 1,
|
125 |
-
'width': resolution*4,
|
126 |
-
'height': resolution*2,
|
127 |
-
'output_type': 'np',
|
128 |
-
'generator': generator,
|
129 |
-
'joint_attention_kwargs': {"scale": lora_scale}
|
130 |
-
}
|
131 |
-
|
132 |
-
# append controlnet hparams
|
133 |
-
if len(control_image) > 0:
|
134 |
-
assert len(control_mode) == len(control_image) # the count of image should be the same as control mode
|
135 |
-
|
136 |
-
ctrl_hparams = {
|
137 |
-
'control_mode': [control_mode_dict[mode_] for mode_ in control_mode],
|
138 |
-
'control_image': control_image,
|
139 |
-
'control_guidance_start': control_guidance_start or [0.0 for i in range(len(control_image))],
|
140 |
-
'control_guidance_end': control_guidance_end or [1.0 for i in range(len(control_image))],
|
141 |
-
'controlnet_conditioning_scale': controlnet_conditioning_scale or [1.0 for i in range(len(control_image))],
|
142 |
-
}
|
143 |
-
|
144 |
-
hparam_dict.update(ctrl_hparams)
|
145 |
-
|
146 |
-
# generate multi-view images
|
147 |
-
with torch.no_grad():
|
148 |
-
image = flux_pipe(
|
149 |
-
**hparam_dict
|
150 |
-
).images
|
151 |
-
return image
|
152 |
-
|
153 |
-
# captioning
|
154 |
-
def run_captioning(image):
|
155 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
156 |
-
torch_dtype = torch.bfloat16
|
157 |
-
|
158 |
-
if isinstance(image, str): # If image is a file path
|
159 |
-
image = Image.open(image).convert("RGB")
|
160 |
-
|
161 |
-
prompt = "<MORE_DETAILED_CAPTION>"
|
162 |
-
inputs = caption_processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
|
163 |
-
# print(f"inputs {inputs}")
|
164 |
-
|
165 |
-
generated_ids = caption_model.generate(
|
166 |
-
input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3
|
167 |
-
)
|
168 |
-
|
169 |
-
generated_text = caption_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
170 |
-
parsed_answer = caption_processor.post_process_generation(
|
171 |
-
generated_text, task=prompt, image_size=(image.width, image.height)
|
172 |
-
)
|
173 |
-
# print(f"parsed_answer = {parsed_answer}")
|
174 |
-
caption_text = parsed_answer["<MORE_DETAILED_CAPTION>"].replace("The image is ", "")
|
175 |
-
return caption_text
|
176 |
-
|
177 |
-
|
178 |
-
# zero123++ multi-view generation
|
179 |
-
def multi_view_rgb_generation(cond_img):
|
180 |
-
# generate multi-view images
|
181 |
-
with torch.no_grad():
|
182 |
-
output_image = zero123plus_pipeline(
|
183 |
-
cond_img,
|
184 |
-
num_inference_steps=zero123plus_diffusion_steps,
|
185 |
-
width=resolution*2,
|
186 |
-
height=resolution*2,
|
187 |
-
).images[0]
|
188 |
-
return output_image
|
189 |
-
|
190 |
-
# lrm reconstructions
|
191 |
-
def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False, render_azimuths=None, render_elevations=None, render_radius=None, render_fov=30):
|
192 |
-
images = image.unsqueeze(0).to(device)
|
193 |
-
images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
|
194 |
-
# breakpoint()
|
195 |
-
with torch.no_grad():
|
196 |
-
# get triplane
|
197 |
-
planes = model.forward_planes(images, input_cameras)
|
198 |
-
|
199 |
-
mesh_path_idx = os.path.join(save_path, f'{name}.obj')
|
200 |
-
|
201 |
-
mesh_out = model.extract_mesh(
|
202 |
-
planes,
|
203 |
-
use_texture_map=export_texmap,
|
204 |
-
**infer_config,
|
205 |
-
)
|
206 |
-
if export_texmap:
|
207 |
-
vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
|
208 |
-
save_obj_with_mtl(
|
209 |
-
vertices.data.cpu().numpy(),
|
210 |
-
uvs.data.cpu().numpy(),
|
211 |
-
faces.data.cpu().numpy(),
|
212 |
-
mesh_tex_idx.data.cpu().numpy(),
|
213 |
-
tex_map.permute(1, 2, 0).data.cpu().numpy(),
|
214 |
-
mesh_path_idx,
|
215 |
-
)
|
216 |
-
else:
|
217 |
-
vertices, faces, vertex_colors = mesh_out
|
218 |
-
save_obj(vertices, faces, vertex_colors, mesh_path_idx)
|
219 |
-
print(f"Mesh saved to {mesh_path_idx}")
|
220 |
-
|
221 |
-
render_size = 512
|
222 |
-
if if_save_video:
|
223 |
-
video_path_idx = os.path.join(save_path, f'{name}.mp4')
|
224 |
-
render_size = infer_config.render_resolution
|
225 |
-
ENV = load_mipmap("models/lrm/env_mipmap/6")
|
226 |
-
materials = (0.0,0.9)
|
227 |
-
|
228 |
-
all_mv, all_mvp, all_campos = get_render_cameras_video(
|
229 |
-
batch_size=1,
|
230 |
-
M=240,
|
231 |
-
radius=4.5,
|
232 |
-
elevation=(90, 60.0),
|
233 |
-
is_flexicubes=True,
|
234 |
-
fov=30
|
235 |
-
)
|
236 |
-
|
237 |
-
frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
|
238 |
-
model,
|
239 |
-
planes,
|
240 |
-
render_cameras=all_mvp,
|
241 |
-
camera_pos=all_campos,
|
242 |
-
env=ENV,
|
243 |
-
materials=materials,
|
244 |
-
render_size=render_size,
|
245 |
-
chunk_size=20,
|
246 |
-
is_flexicubes=True,
|
247 |
-
)
|
248 |
-
normals = (torch.nn.functional.normalize(normals) + 1) / 2
|
249 |
-
normals = normals * alphas + (1-alphas)
|
250 |
-
all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
|
251 |
-
|
252 |
-
# breakpoint()
|
253 |
-
save_video(
|
254 |
-
all_frames,
|
255 |
-
video_path_idx,
|
256 |
-
fps=30,
|
257 |
-
)
|
258 |
-
print(f"Video saved to {video_path_idx}")
|
259 |
-
|
260 |
-
if render_azimuths is not None and render_elevations is not None and render_radius is not None:
|
261 |
-
render_size = infer_config.render_resolution
|
262 |
-
ENV = load_mipmap("models/lrm/env_mipmap/6")
|
263 |
-
materials = (0.0,0.9)
|
264 |
-
all_mv, all_mvp, all_campos, identity_mv = get_render_cameras_frames(
|
265 |
-
batch_size=1,
|
266 |
-
radius=render_radius,
|
267 |
-
azimuths=render_azimuths,
|
268 |
-
elevations=render_elevations,
|
269 |
-
fov=30
|
270 |
-
)
|
271 |
-
frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
|
272 |
-
model,
|
273 |
-
planes,
|
274 |
-
render_cameras=all_mvp,
|
275 |
-
camera_pos=all_campos,
|
276 |
-
env=ENV,
|
277 |
-
materials=materials,
|
278 |
-
render_size=render_size,
|
279 |
-
render_mv = all_mv,
|
280 |
-
local_normal=True,
|
281 |
-
identity_mv=identity_mv,
|
282 |
-
)
|
283 |
-
else:
|
284 |
-
normals = None
|
285 |
-
frames = None
|
286 |
-
albedos = None
|
287 |
-
|
288 |
-
return vertices, faces, normals, frames, albedos
|
289 |
-
|
290 |
-
|
291 |
-
def transform_normal(input_normal, azimuths_deg, elevations_deg, radius=4.5, is_global_to_local=False):
|
292 |
-
"""
|
293 |
-
input_normal: in range [-1, 1], shape (b c h w)
|
294 |
-
"""
|
295 |
-
|
296 |
-
input_normal = input_normal.permute(0, 2, 3, 1).cpu()
|
297 |
-
|
298 |
-
azimuths_deg = np.array(azimuths_deg)
|
299 |
-
elevations_deg = np.array(elevations_deg)
|
300 |
-
|
301 |
-
if is_global_to_local:
|
302 |
-
local_normal = normal_transfer.trans_global_2_local(input_normal, azimuths_deg, elevations_deg)
|
303 |
-
return local_normal.permute(0, 3, 1, 2)
|
304 |
-
else:
|
305 |
-
global_normal = normal_transfer.trans_local_2_global(input_normal, azimuths_deg, elevations_deg, radius=radius, for_lotus=False)
|
306 |
-
global_normal[..., 0] *= -1
|
307 |
-
return global_normal.permute(0, 3, 1, 2)
|
308 |
-
|
309 |
-
def local_normal_global_transform(local_normal_images,azimuths_deg,elevations_deg):
|
310 |
-
if local_normal_images.min() >= 0:
|
311 |
-
local_normal = local_normal_images.float() * 2 - 1
|
312 |
-
else:
|
313 |
-
local_normal = local_normal_images.float()
|
314 |
-
global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
|
315 |
-
global_normal[...,0] *= -1
|
316 |
-
global_normal = (global_normal + 1) / 2
|
317 |
-
global_normal = global_normal.permute(0, 3, 1, 2)
|
318 |
-
return global_normal
|
319 |
-
|
320 |
-
def main():
|
321 |
-
image_pth = "examples/蓝色小怪物.webp"
|
322 |
-
save_dir_path = os.path.join(save_dir, image_pth.split("/")[-1].split(".")[0])
|
323 |
-
os.makedirs(save_dir_path, exist_ok=True)
|
324 |
-
input_image = Image.open(image_pth)
|
325 |
-
# if not args.no_rembg:
|
326 |
-
input_image = remove_background(input_image, rembg_session)
|
327 |
-
input_image = resize_foreground(input_image, 0.85)
|
328 |
-
|
329 |
-
# generate caption
|
330 |
-
image_caption = run_captioning(image_pth)
|
331 |
-
|
332 |
-
# generate multi-view images
|
333 |
-
output_image = multi_view_rgb_generation(input_image)
|
334 |
-
|
335 |
-
# lrm reconstructions
|
336 |
-
rgb_multi_view = np.asarray(output_image, dtype=np.float32) / 255.0
|
337 |
-
rgb_multi_view = torch.from_numpy(rgb_multi_view).squeeze(0).permute(2, 0, 1).contiguous().float() # (3, 1024, 2048)
|
338 |
-
rgb_multi_view = rearrange(rgb_multi_view, 'c (n h) (m w) -> (n m) c h w', n=2, m=2) # (8, 3, 512, 512)
|
339 |
-
|
340 |
-
input_cameras = get_custom_zero123plus_input_cameras(batch_size=1, radius=3.5, fov=30).to(device)
|
341 |
-
|
342 |
-
vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = \
|
343 |
-
lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm',
|
344 |
-
export_texmap=False, if_save_video=False, render_azimuths=isomer_azimuths,
|
345 |
-
render_elevations=isomer_elevations, render_radius=isomer_radius, render_fov=30)
|
346 |
-
|
347 |
-
vertices = torch.from_numpy(vertices).to(device)
|
348 |
-
faces = torch.from_numpy(faces).to(device)
|
349 |
-
vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
|
350 |
-
vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]
|
351 |
-
|
352 |
-
|
353 |
-
# lrm_3D_bundle_image = torchvision.utils.make_grid(torch.cat([lrm_multi_view_rgb.cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0) # range [0, 1]
|
354 |
-
lrm_3D_bundle_image = torchvision.utils.make_grid(torch.cat([rgb_multi_view[[3,0,1,2]].cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0) # range [0, 1]
|
355 |
-
# rgb_multi_view[[3,0,1,2]] : (B,3,H,W)
|
356 |
-
# lrm_multi_view_normals : (B,3,H,W)
|
357 |
-
# combined_images = 0.5 * rgb_multi_view[[3,0,1,2]].cpu() + 0.5 * (lrm_multi_view_normals.cpu() + 1) / 2
|
358 |
-
# torchvision.utils.save_image(combined_images, os.path.join("debug_output", 'combined.png'))
|
359 |
-
# breakpoint()
|
360 |
-
# Use the low-quality controlnet by default, feel free to try the others
|
361 |
-
control_image = [lrm_3D_bundle_image * 2 - 1]
|
362 |
-
control_mode = ['tile']
|
363 |
-
control_guidance_start = [0.0]
|
364 |
-
control_guidance_end = [0.3]
|
365 |
-
controlnet_conditioning_scale = [0.8]
|
366 |
-
|
367 |
-
flux_pipe.controlnet = FluxMultiControlNetModel([flux_controlnet for _ in control_mode])
|
368 |
-
# breakpoint()
|
369 |
-
rgb_normal_grid = multi_view_rgb_normal_generation_with_controlnet(
|
370 |
-
prompt= ' '.join(['A grid of 2x4 multi-view image, elevation 5. White background.', image_caption]),
|
371 |
-
image=lrm_3D_bundle_image,
|
372 |
-
strength=0.6,
|
373 |
-
control_image=control_image,
|
374 |
-
control_mode=control_mode,
|
375 |
-
control_guidance_start=control_guidance_start,
|
376 |
-
control_guidance_end=control_guidance_end,
|
377 |
-
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
378 |
-
lora_scale=1.0
|
379 |
-
) # noted that rgb_normal_grid is a (b, h, w, c) numpy array
|
380 |
-
|
381 |
-
rgb_normal_grid = torch.from_numpy(rgb_normal_grid).contiguous().float()
|
382 |
-
rgb_normal_grid = rearrange(rgb_normal_grid.squeeze(0), '(n h) (m w) c-> (n m) c h w', n=2, m=4) # (8, 3, 512, 512)
|
383 |
-
rgb_multi_view = rgb_normal_grid[:4, :3, :, :].cuda()
|
384 |
-
normal_multi_view = rgb_normal_grid[4:, :3, :, :].cuda()
|
385 |
-
multi_view_mask = get_background(normal_multi_view).cuda()
|
386 |
-
rgb_multi_view = rgb_multi_view * multi_view_mask + (1-multi_view_mask)
|
387 |
-
|
388 |
-
# local normal to global normal
|
389 |
-
global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1).cpu(), isomer_azimuths, isomer_elevations).cuda()
|
390 |
-
|
391 |
-
global_normal = global_normal * multi_view_mask + (1-multi_view_mask)
|
392 |
-
|
393 |
-
global_normal = global_normal.permute(0,2,3,1)
|
394 |
-
multi_view_mask = multi_view_mask.squeeze(1)
|
395 |
-
rgb_multi_view = rgb_multi_view.permute(0,2,3,1)
|
396 |
-
# global_normal: B,H,W,3
|
397 |
-
# multi_view_mask: B,H,W
|
398 |
-
# rgb_multi_view: B,H,W,3
|
399 |
-
|
400 |
-
|
401 |
-
meshes = reconstruction(
|
402 |
-
normal_pils=global_normal,
|
403 |
-
masks=multi_view_mask,
|
404 |
-
weights=isomer_geo_weights,
|
405 |
-
fov=30,
|
406 |
-
radius=isomer_radius,
|
407 |
-
camera_angles_azi=isomer_azimuths,
|
408 |
-
camera_angles_ele=isomer_elevations,
|
409 |
-
expansion_weight_stage1=0.1,
|
410 |
-
init_type="file",
|
411 |
-
init_verts=vertices,
|
412 |
-
init_faces=faces,
|
413 |
-
stage1_steps=0,
|
414 |
-
stage2_steps=50,
|
415 |
-
start_edge_len_stage1=0.1,
|
416 |
-
end_edge_len_stage1=0.02,
|
417 |
-
start_edge_len_stage2=0.02,
|
418 |
-
end_edge_len_stage2=0.005,
|
419 |
-
)
|
420 |
-
|
421 |
-
save_glb_addr = projection(
|
422 |
-
meshes=meshes,
|
423 |
-
masks=multi_view_mask,
|
424 |
-
images=rgb_multi_view,
|
425 |
-
azimuths=isomer_azimuths,
|
426 |
-
elevations=isomer_elevations,
|
427 |
-
weights=isomer_color_weights,
|
428 |
-
fov=30,
|
429 |
-
radius=isomer_radius,
|
430 |
-
save_dir=f"{save_dir_path}/ISOMER/",
|
431 |
-
)
|
432 |
-
print(f'saved to {save_glb_addr}')
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
if __name__ == '__main__':
|
437 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
image_to_mesh_new.py
DELETED
@@ -1,436 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from einops import rearrange
|
3 |
-
from omegaconf import OmegaConf
|
4 |
-
import torch
|
5 |
-
import numpy as np
|
6 |
-
import trimesh
|
7 |
-
import torchvision
|
8 |
-
import torch.nn.functional as F
|
9 |
-
from PIL import Image
|
10 |
-
from torchvision import transforms
|
11 |
-
from torchvision.transforms import v2
|
12 |
-
from transformers import AutoProcessor, AutoModelForCausalLM
|
13 |
-
import rembg
|
14 |
-
from diffusers import FluxPipeline, FluxControlNetImg2ImgPipeline
|
15 |
-
from diffusers.models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
|
16 |
-
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, HeunDiscreteScheduler
|
17 |
-
from pytorch_lightning import seed_everything
|
18 |
-
import os
|
19 |
-
|
20 |
-
from models.ISOMER.reconstruction_func import reconstruction
|
21 |
-
from models.ISOMER.projection_func import projection
|
22 |
-
from models.lrm.utils.infer_util import remove_background, resize_foreground, save_video
|
23 |
-
from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
|
24 |
-
from models.lrm.utils.render_utils import rotate_x, rotate_y
|
25 |
-
from models.lrm.utils.train_util import instantiate_from_config
|
26 |
-
from models.lrm.utils.camera_util import get_zero123plus_input_cameras, get_custom_zero123plus_input_cameras, get_flux_input_cameras
|
27 |
-
from utils.tool import NormalTransfer, get_render_cameras_frames, load_mipmap
|
28 |
-
from utils.tool import get_background, get_render_cameras_video, render_frames, mask_fix
|
29 |
-
|
30 |
-
device = "cuda"
|
31 |
-
resolution = 512
|
32 |
-
save_dir = "./outputs"
|
33 |
-
zero123plus_diffusion_steps = 75
|
34 |
-
normal_transfer = NormalTransfer()
|
35 |
-
rembg_session = rembg.new_session()
|
36 |
-
isomer_azimuths = torch.from_numpy(np.array([270, 0, 90, 180])).to(device)
|
37 |
-
isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).to(device)
|
38 |
-
isomer_radius = 4.1
|
39 |
-
isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device)
|
40 |
-
isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device)
|
41 |
-
# seed_everything(42)
|
42 |
-
|
43 |
-
# model initialization and loading
|
44 |
-
# flux
|
45 |
-
print('==> Loading Flux model ...')
|
46 |
-
flux_base_model_pth = "/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/models--black-forest-labs--FLUX.1-dev"
|
47 |
-
flux_controlnet = FluxControlNetModel.from_pretrained("/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/flux_controlnets/FLUX.1-dev-ControlNet-Union-Pro")
|
48 |
-
flux_pipe = FluxControlNetImg2ImgPipeline.from_pretrained(flux_base_model_pth, controlnet=[flux_controlnet], torch_dtype=torch.bfloat16).to(device=device, dtype=torch.bfloat16)
|
49 |
-
|
50 |
-
flux_pipe.load_lora_weights('./checkpoint/flux_lora/rgb_normal_large.safetensors')
|
51 |
-
|
52 |
-
|
53 |
-
flux_pipe.to(device=device, dtype=torch.bfloat16)
|
54 |
-
generator = torch.Generator(device=device).manual_seed(0)
|
55 |
-
|
56 |
-
# lrm
|
57 |
-
print('==> Loading LRM model ...')
|
58 |
-
config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
|
59 |
-
model_config = config.model_config
|
60 |
-
infer_config = config.infer_config
|
61 |
-
model = instantiate_from_config(model_config)
|
62 |
-
model_ckpt_path = "./checkpoint/lrm/final_ckpt.ckpt"
|
63 |
-
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
64 |
-
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
|
65 |
-
model.load_state_dict(state_dict, strict=True)
|
66 |
-
|
67 |
-
model = model.to(device)
|
68 |
-
model.init_flexicubes_geometry(device, fovy=50.0)
|
69 |
-
model = model.eval()
|
70 |
-
|
71 |
-
# zero123++
|
72 |
-
print('==> Loading diffusion model ...')
|
73 |
-
zero123plus_pipeline = DiffusionPipeline.from_pretrained(
|
74 |
-
"sudo-ai/zero123plus-v1.2",
|
75 |
-
custom_pipeline="./models/zero123plus",
|
76 |
-
torch_dtype=torch.float16,
|
77 |
-
)
|
78 |
-
zero123plus_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
79 |
-
zero123plus_pipeline.scheduler.config, timestep_spacing='trailing'
|
80 |
-
)
|
81 |
-
unet_ckpt_path = "./checkpoint/zero123++/flexgen_19w.ckpt"
|
82 |
-
state_dict = torch.load(unet_ckpt_path, map_location='cpu')['state_dict']
|
83 |
-
state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')}
|
84 |
-
zero123plus_pipeline.unet.load_state_dict(state_dict, strict=True)
|
85 |
-
zero123plus_pipeline = zero123plus_pipeline.to(device)
|
86 |
-
|
87 |
-
# unet_ckpt_path = "checkpoint/zero123++/diffusion_pytorch_model.bin"
|
88 |
-
# state_dict = torch.load(unet_ckpt_path, map_location='cpu')
|
89 |
-
# zero123plus_pipeline.unet.load_state_dict(state_dict, strict=True)
|
90 |
-
# zero123plus_pipeline = zero123plus_pipeline.to(device)
|
91 |
-
|
92 |
-
# florence
|
93 |
-
caption_model = AutoModelForCausalLM.from_pretrained(
|
94 |
-
"/hpc2hdd/home/jlin695/.cache/huggingface/hub/models--multimodalart--Florence-2-large-no-flash-attn/snapshots/8db3793cf5b453b2ccfb3a4f613b403b2e6b7ca2", torch_dtype=torch.bfloat16, trust_remote_code=True,
|
95 |
-
).to(device)
|
96 |
-
caption_processor = AutoProcessor.from_pretrained("/hpc2hdd/home/jlin695/.cache/huggingface/hub/models--multimodalart--Florence-2-large-no-flash-attn/snapshots/8db3793cf5b453b2ccfb3a4f613b403b2e6b7ca2", trust_remote_code=True)
|
97 |
-
|
98 |
-
# Flux multi-view generation
|
99 |
-
def multi_view_rgb_normal_generation_with_controlnet(prompt, image, strength=1.0,
|
100 |
-
control_image=[],
|
101 |
-
control_mode=[],
|
102 |
-
control_guidance_start=None,
|
103 |
-
control_guidance_end=None,
|
104 |
-
controlnet_conditioning_scale=None,
|
105 |
-
lora_scale=1.0
|
106 |
-
):
|
107 |
-
control_mode_dict = {
|
108 |
-
'canny': 0,
|
109 |
-
'tile': 1,
|
110 |
-
'depth': 2,
|
111 |
-
'blur': 3,
|
112 |
-
'pose': 4,
|
113 |
-
'gray': 5,
|
114 |
-
'lq': 6,
|
115 |
-
} # for https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union only
|
116 |
-
|
117 |
-
hparam_dict = {
|
118 |
-
'prompt': prompt,
|
119 |
-
'image': image,
|
120 |
-
'strength': strength,
|
121 |
-
'num_inference_steps': 30,
|
122 |
-
'guidance_scale': 3.5,
|
123 |
-
'num_images_per_prompt': 1,
|
124 |
-
'width': resolution*4,
|
125 |
-
'height': resolution*2,
|
126 |
-
'output_type': 'np',
|
127 |
-
'generator': generator,
|
128 |
-
'joint_attention_kwargs': {"scale": lora_scale}
|
129 |
-
}
|
130 |
-
|
131 |
-
# append controlnet hparams
|
132 |
-
if len(control_image) > 0:
|
133 |
-
assert len(control_mode) == len(control_image) # the count of image should be the same as control mode
|
134 |
-
|
135 |
-
ctrl_hparams = {
|
136 |
-
'control_mode': [control_mode_dict[mode_] for mode_ in control_mode],
|
137 |
-
'control_image': control_image,
|
138 |
-
'control_guidance_start': control_guidance_start or [0.0 for i in range(len(control_image))],
|
139 |
-
'control_guidance_end': control_guidance_end or [1.0 for i in range(len(control_image))],
|
140 |
-
'controlnet_conditioning_scale': controlnet_conditioning_scale or [1.0 for i in range(len(control_image))],
|
141 |
-
}
|
142 |
-
|
143 |
-
hparam_dict.update(ctrl_hparams)
|
144 |
-
|
145 |
-
# generate multi-view images
|
146 |
-
with torch.no_grad():
|
147 |
-
image = flux_pipe(
|
148 |
-
**hparam_dict
|
149 |
-
).images
|
150 |
-
return image
|
151 |
-
|
152 |
-
# captioning
|
153 |
-
def run_captioning(image):
|
154 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
155 |
-
torch_dtype = torch.bfloat16
|
156 |
-
|
157 |
-
if isinstance(image, str): # If image is a file path
|
158 |
-
image = Image.open(image).convert("RGB")
|
159 |
-
|
160 |
-
prompt = "<MORE_DETAILED_CAPTION>"
|
161 |
-
inputs = caption_processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
|
162 |
-
# print(f"inputs {inputs}")
|
163 |
-
|
164 |
-
generated_ids = caption_model.generate(
|
165 |
-
input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3
|
166 |
-
)
|
167 |
-
|
168 |
-
generated_text = caption_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
169 |
-
parsed_answer = caption_processor.post_process_generation(
|
170 |
-
generated_text, task=prompt, image_size=(image.width, image.height)
|
171 |
-
)
|
172 |
-
# print(f"parsed_answer = {parsed_answer}")
|
173 |
-
caption_text = parsed_answer["<MORE_DETAILED_CAPTION>"].replace("The image is ", "")
|
174 |
-
return caption_text
|
175 |
-
|
176 |
-
|
177 |
-
# zero123++ multi-view generation
|
178 |
-
def multi_view_rgb_generation(cond_img):
|
179 |
-
# generate multi-view images
|
180 |
-
with torch.no_grad():
|
181 |
-
output_image = zero123plus_pipeline(
|
182 |
-
cond_img,
|
183 |
-
num_inference_steps=zero123plus_diffusion_steps,
|
184 |
-
width=resolution*2,
|
185 |
-
height=resolution*2,
|
186 |
-
).images[0]
|
187 |
-
return output_image
|
188 |
-
|
189 |
-
# lrm reconstructions
|
190 |
-
def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False, render_azimuths=None, render_elevations=None, render_radius=None, render_fov=30):
|
191 |
-
images = image.unsqueeze(0).to(device)
|
192 |
-
images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
|
193 |
-
# breakpoint()
|
194 |
-
with torch.no_grad():
|
195 |
-
# get triplane
|
196 |
-
planes = model.forward_planes(images, input_cameras)
|
197 |
-
|
198 |
-
mesh_path_idx = os.path.join(save_path, f'{name}.obj')
|
199 |
-
|
200 |
-
mesh_out = model.extract_mesh(
|
201 |
-
planes,
|
202 |
-
use_texture_map=export_texmap,
|
203 |
-
**infer_config,
|
204 |
-
)
|
205 |
-
if export_texmap:
|
206 |
-
vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
|
207 |
-
save_obj_with_mtl(
|
208 |
-
vertices.data.cpu().numpy(),
|
209 |
-
uvs.data.cpu().numpy(),
|
210 |
-
faces.data.cpu().numpy(),
|
211 |
-
mesh_tex_idx.data.cpu().numpy(),
|
212 |
-
tex_map.permute(1, 2, 0).data.cpu().numpy(),
|
213 |
-
mesh_path_idx,
|
214 |
-
)
|
215 |
-
else:
|
216 |
-
vertices, faces, vertex_colors = mesh_out
|
217 |
-
save_obj(vertices, faces, vertex_colors, mesh_path_idx)
|
218 |
-
print(f"Mesh saved to {mesh_path_idx}")
|
219 |
-
|
220 |
-
render_size = 512
|
221 |
-
if if_save_video:
|
222 |
-
video_path_idx = os.path.join(save_path, f'{name}.mp4')
|
223 |
-
render_size = infer_config.render_resolution
|
224 |
-
ENV = load_mipmap("models/lrm/env_mipmap/6")
|
225 |
-
materials = (0.0,0.9)
|
226 |
-
|
227 |
-
all_mv, all_mvp, all_campos = get_render_cameras_video(
|
228 |
-
batch_size=1,
|
229 |
-
M=240,
|
230 |
-
radius=4.5,
|
231 |
-
elevation=(90, 60.0),
|
232 |
-
is_flexicubes=True,
|
233 |
-
fov=30
|
234 |
-
)
|
235 |
-
|
236 |
-
frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
|
237 |
-
model,
|
238 |
-
planes,
|
239 |
-
render_cameras=all_mvp,
|
240 |
-
camera_pos=all_campos,
|
241 |
-
env=ENV,
|
242 |
-
materials=materials,
|
243 |
-
render_size=render_size,
|
244 |
-
chunk_size=20,
|
245 |
-
is_flexicubes=True,
|
246 |
-
)
|
247 |
-
normals = (torch.nn.functional.normalize(normals) + 1) / 2
|
248 |
-
normals = normals * alphas + (1-alphas)
|
249 |
-
all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
|
250 |
-
|
251 |
-
# breakpoint()
|
252 |
-
save_video(
|
253 |
-
all_frames,
|
254 |
-
video_path_idx,
|
255 |
-
fps=30,
|
256 |
-
)
|
257 |
-
print(f"Video saved to {video_path_idx}")
|
258 |
-
|
259 |
-
if render_azimuths is not None and render_elevations is not None and render_radius is not None:
|
260 |
-
render_size = infer_config.render_resolution
|
261 |
-
ENV = load_mipmap("models/lrm/env_mipmap/6")
|
262 |
-
materials = (0.0,0.9)
|
263 |
-
all_mv, all_mvp, all_campos, identity_mv = get_render_cameras_frames(
|
264 |
-
batch_size=1,
|
265 |
-
radius=render_radius,
|
266 |
-
azimuths=render_azimuths,
|
267 |
-
elevations=render_elevations,
|
268 |
-
fov=30
|
269 |
-
)
|
270 |
-
frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
|
271 |
-
model,
|
272 |
-
planes,
|
273 |
-
render_cameras=all_mvp,
|
274 |
-
camera_pos=all_campos,
|
275 |
-
env=ENV,
|
276 |
-
materials=materials,
|
277 |
-
render_size=render_size,
|
278 |
-
render_mv = all_mv,
|
279 |
-
local_normal=True,
|
280 |
-
identity_mv=identity_mv,
|
281 |
-
)
|
282 |
-
else:
|
283 |
-
normals = None
|
284 |
-
frames = None
|
285 |
-
albedos = None
|
286 |
-
|
287 |
-
return vertices, faces, normals, frames, albedos
|
288 |
-
|
289 |
-
|
290 |
-
def transform_normal(input_normal, azimuths_deg, elevations_deg, radius=4.5, is_global_to_local=False):
|
291 |
-
"""
|
292 |
-
input_normal: in range [-1, 1], shape (b c h w)
|
293 |
-
"""
|
294 |
-
|
295 |
-
input_normal = input_normal.permute(0, 2, 3, 1).cpu()
|
296 |
-
|
297 |
-
azimuths_deg = np.array(azimuths_deg)
|
298 |
-
elevations_deg = np.array(elevations_deg)
|
299 |
-
|
300 |
-
if is_global_to_local:
|
301 |
-
local_normal = normal_transfer.trans_global_2_local(input_normal, azimuths_deg, elevations_deg)
|
302 |
-
return local_normal.permute(0, 3, 1, 2)
|
303 |
-
else:
|
304 |
-
global_normal = normal_transfer.trans_local_2_global(input_normal, azimuths_deg, elevations_deg, radius=radius, for_lotus=False)
|
305 |
-
global_normal[..., 0] *= -1
|
306 |
-
return global_normal.permute(0, 3, 1, 2)
|
307 |
-
|
308 |
-
def local_normal_global_transform(local_normal_images,azimuths_deg,elevations_deg):
|
309 |
-
if local_normal_images.min() >= 0:
|
310 |
-
local_normal = local_normal_images.float() * 2 - 1
|
311 |
-
else:
|
312 |
-
local_normal = local_normal_images.float()
|
313 |
-
global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
|
314 |
-
global_normal[...,0] *= -1
|
315 |
-
global_normal = (global_normal + 1) / 2
|
316 |
-
global_normal = global_normal.permute(0, 3, 1, 2)
|
317 |
-
return global_normal
|
318 |
-
|
319 |
-
def main():
|
320 |
-
image_pth = "examples/蓝色小怪物.webp"
|
321 |
-
save_dir_path = os.path.join(save_dir, image_pth.split("/")[-1].split(".")[0])
|
322 |
-
os.makedirs(save_dir_path, exist_ok=True)
|
323 |
-
input_image = Image.open(image_pth)
|
324 |
-
# if not args.no_rembg:
|
325 |
-
input_image = remove_background(input_image, rembg_session)
|
326 |
-
input_image = resize_foreground(input_image, 0.85)
|
327 |
-
|
328 |
-
# generate caption
|
329 |
-
image_caption = run_captioning(image_pth)
|
330 |
-
|
331 |
-
# generate multi-view images
|
332 |
-
output_image = multi_view_rgb_generation(input_image)
|
333 |
-
|
334 |
-
# lrm reconstructions
|
335 |
-
rgb_multi_view = np.asarray(output_image, dtype=np.float32) / 255.0
|
336 |
-
rgb_multi_view = torch.from_numpy(rgb_multi_view).squeeze(0).permute(2, 0, 1).contiguous().float() # (3, 1024, 2048)
|
337 |
-
rgb_multi_view = rearrange(rgb_multi_view, 'c (n h) (m w) -> (n m) c h w', n=2, m=2) # (8, 3, 512, 512)
|
338 |
-
|
339 |
-
input_cameras = get_custom_zero123plus_input_cameras(batch_size=1, radius=3.5, fov=30).to(device)
|
340 |
-
|
341 |
-
vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = \
|
342 |
-
lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm',
|
343 |
-
export_texmap=False, if_save_video=False, render_azimuths=isomer_azimuths,
|
344 |
-
render_elevations=isomer_elevations, render_radius=isomer_radius, render_fov=30)
|
345 |
-
|
346 |
-
vertices = torch.from_numpy(vertices).to(device)
|
347 |
-
faces = torch.from_numpy(faces).to(device)
|
348 |
-
vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
|
349 |
-
vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]
|
350 |
-
|
351 |
-
|
352 |
-
# lrm_3D_bundle_image = torchvision.utils.make_grid(torch.cat([lrm_multi_view_rgb.cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0) # range [0, 1]
|
353 |
-
lrm_3D_bundle_image = torchvision.utils.make_grid(torch.cat([rgb_multi_view[[3,0,1,2]].cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0) # range [0, 1]
|
354 |
-
# rgb_multi_view[[3,0,1,2]] : (B,3,H,W)
|
355 |
-
# lrm_multi_view_normals : (B,3,H,W)
|
356 |
-
# combined_images = 0.5 * rgb_multi_view[[3,0,1,2]].cpu() + 0.5 * (lrm_multi_view_normals.cpu() + 1) / 2
|
357 |
-
# torchvision.utils.save_image(combined_images, os.path.join("debug_output", 'combined.png'))
|
358 |
-
# breakpoint()
|
359 |
-
# Use the low-quality controlnet by default, feel free to try the others
|
360 |
-
control_image = [lrm_3D_bundle_image * 2 - 1]
|
361 |
-
control_mode = ['tile']
|
362 |
-
control_guidance_start = [0.0]
|
363 |
-
control_guidance_end = [0.3]
|
364 |
-
controlnet_conditioning_scale = [0.8]
|
365 |
-
|
366 |
-
flux_pipe.controlnet = FluxMultiControlNetModel([flux_controlnet for _ in control_mode])
|
367 |
-
# breakpoint()
|
368 |
-
rgb_normal_grid = multi_view_rgb_normal_generation_with_controlnet(
|
369 |
-
prompt= ' '.join(['A grid of 2x4 multi-view image, elevation 5. White background.', image_caption]),
|
370 |
-
image=lrm_3D_bundle_image,
|
371 |
-
strength=0.6,
|
372 |
-
control_image=control_image,
|
373 |
-
control_mode=control_mode,
|
374 |
-
control_guidance_start=control_guidance_start,
|
375 |
-
control_guidance_end=control_guidance_end,
|
376 |
-
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
377 |
-
lora_scale=1.0
|
378 |
-
) # noted that rgb_normal_grid is a (b, h, w, c) numpy array
|
379 |
-
|
380 |
-
rgb_normal_grid = torch.from_numpy(rgb_normal_grid).contiguous().float()
|
381 |
-
rgb_normal_grid = rearrange(rgb_normal_grid.squeeze(0), '(n h) (m w) c-> (n m) c h w', n=2, m=4) # (8, 3, 512, 512)
|
382 |
-
rgb_multi_view = rgb_normal_grid[:4, :3, :, :].cuda()
|
383 |
-
normal_multi_view = rgb_normal_grid[4:, :3, :, :].cuda()
|
384 |
-
multi_view_mask = get_background(normal_multi_view).cuda()
|
385 |
-
rgb_multi_view = rgb_multi_view * multi_view_mask + (1-multi_view_mask)
|
386 |
-
|
387 |
-
# local normal to global normal
|
388 |
-
global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1).cpu(), isomer_azimuths, isomer_elevations).cuda()
|
389 |
-
|
390 |
-
global_normal = global_normal * multi_view_mask + (1-multi_view_mask)
|
391 |
-
|
392 |
-
global_normal = global_normal.permute(0,2,3,1)
|
393 |
-
multi_view_mask = multi_view_mask.squeeze(1)
|
394 |
-
rgb_multi_view = rgb_multi_view.permute(0,2,3,1)
|
395 |
-
# global_normal: B,H,W,3
|
396 |
-
# multi_view_mask: B,H,W
|
397 |
-
# rgb_multi_view: B,H,W,3
|
398 |
-
|
399 |
-
|
400 |
-
meshes = reconstruction(
|
401 |
-
normal_pils=global_normal,
|
402 |
-
masks=multi_view_mask,
|
403 |
-
weights=isomer_geo_weights,
|
404 |
-
fov=30,
|
405 |
-
radius=isomer_radius,
|
406 |
-
camera_angles_azi=isomer_azimuths,
|
407 |
-
camera_angles_ele=isomer_elevations,
|
408 |
-
expansion_weight_stage1=0.1,
|
409 |
-
init_type="file",
|
410 |
-
init_verts=vertices,
|
411 |
-
init_faces=faces,
|
412 |
-
stage1_steps=0,
|
413 |
-
stage2_steps=50,
|
414 |
-
start_edge_len_stage1=0.1,
|
415 |
-
end_edge_len_stage1=0.02,
|
416 |
-
start_edge_len_stage2=0.02,
|
417 |
-
end_edge_len_stage2=0.005,
|
418 |
-
)
|
419 |
-
|
420 |
-
save_glb_addr = projection(
|
421 |
-
meshes=meshes,
|
422 |
-
masks=multi_view_mask,
|
423 |
-
images=rgb_multi_view,
|
424 |
-
azimuths=isomer_azimuths,
|
425 |
-
elevations=isomer_elevations,
|
426 |
-
weights=isomer_color_weights,
|
427 |
-
fov=30,
|
428 |
-
radius=isomer_radius,
|
429 |
-
save_dir=f"{save_dir_path}/ISOMER/",
|
430 |
-
)
|
431 |
-
print(f'saved to {save_glb_addr}')
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
if __name__ == '__main__':
|
436 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
live_preview_helpers.py
DELETED
@@ -1,167 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import numpy as np
|
3 |
-
from diffusers import FluxPipeline, AutoencoderTiny, FlowMatchEulerDiscreteScheduler
|
4 |
-
from typing import Any, Dict, List, Optional, Union
|
5 |
-
|
6 |
-
# Helper functions
|
7 |
-
def calculate_shift(
|
8 |
-
image_seq_len,
|
9 |
-
base_seq_len: int = 256,
|
10 |
-
max_seq_len: int = 4096,
|
11 |
-
base_shift: float = 0.5,
|
12 |
-
max_shift: float = 1.16,
|
13 |
-
):
|
14 |
-
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
15 |
-
b = base_shift - m * base_seq_len
|
16 |
-
mu = image_seq_len * m + b
|
17 |
-
return mu
|
18 |
-
|
19 |
-
def retrieve_timesteps(
|
20 |
-
scheduler,
|
21 |
-
num_inference_steps: Optional[int] = None,
|
22 |
-
device: Optional[Union[str, torch.device]] = None,
|
23 |
-
timesteps: Optional[List[int]] = None,
|
24 |
-
sigmas: Optional[List[float]] = None,
|
25 |
-
**kwargs,
|
26 |
-
):
|
27 |
-
if timesteps is not None and sigmas is not None:
|
28 |
-
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
29 |
-
if timesteps is not None:
|
30 |
-
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
31 |
-
timesteps = scheduler.timesteps
|
32 |
-
num_inference_steps = len(timesteps)
|
33 |
-
elif sigmas is not None:
|
34 |
-
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
35 |
-
timesteps = scheduler.timesteps
|
36 |
-
num_inference_steps = len(timesteps)
|
37 |
-
else:
|
38 |
-
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
39 |
-
timesteps = scheduler.timesteps
|
40 |
-
return timesteps, num_inference_steps
|
41 |
-
|
42 |
-
# FLUX pipeline function
|
43 |
-
@torch.inference_mode()
|
44 |
-
def flux_pipe_call_that_returns_an_iterable_of_images(
|
45 |
-
self,
|
46 |
-
prompt: Union[str, List[str]] = None,
|
47 |
-
prompt_2: Optional[Union[str, List[str]]] = None,
|
48 |
-
height: Optional[int] = None,
|
49 |
-
width: Optional[int] = None,
|
50 |
-
num_inference_steps: int = 28,
|
51 |
-
timesteps: List[int] = None,
|
52 |
-
guidance_scale: float = 3.5,
|
53 |
-
num_images_per_prompt: Optional[int] = 1,
|
54 |
-
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
55 |
-
latents: Optional[torch.FloatTensor] = None,
|
56 |
-
prompt_embeds: Optional[torch.FloatTensor] = None,
|
57 |
-
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
58 |
-
output_type: Optional[str] = "pil",
|
59 |
-
return_dict: bool = True,
|
60 |
-
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
61 |
-
max_sequence_length: int = 512,
|
62 |
-
good_vae: Optional[Any] = None,
|
63 |
-
):
|
64 |
-
height = height or self.default_sample_size * self.vae_scale_factor
|
65 |
-
width = width or self.default_sample_size * self.vae_scale_factor
|
66 |
-
|
67 |
-
# 1. Check inputs
|
68 |
-
self.check_inputs(
|
69 |
-
prompt,
|
70 |
-
prompt_2,
|
71 |
-
height,
|
72 |
-
width,
|
73 |
-
prompt_embeds=prompt_embeds,
|
74 |
-
pooled_prompt_embeds=pooled_prompt_embeds,
|
75 |
-
max_sequence_length=max_sequence_length,
|
76 |
-
)
|
77 |
-
|
78 |
-
self._guidance_scale = guidance_scale
|
79 |
-
self._joint_attention_kwargs = joint_attention_kwargs
|
80 |
-
self._interrupt = False
|
81 |
-
|
82 |
-
# 2. Define call parameters
|
83 |
-
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
84 |
-
device = self._execution_device
|
85 |
-
|
86 |
-
# 3. Encode prompt
|
87 |
-
lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
|
88 |
-
prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
|
89 |
-
prompt=prompt,
|
90 |
-
prompt_2=prompt_2,
|
91 |
-
prompt_embeds=prompt_embeds,
|
92 |
-
pooled_prompt_embeds=pooled_prompt_embeds,
|
93 |
-
device=device,
|
94 |
-
num_images_per_prompt=num_images_per_prompt,
|
95 |
-
max_sequence_length=max_sequence_length,
|
96 |
-
lora_scale=lora_scale,
|
97 |
-
)
|
98 |
-
# 4. Prepare latent variables
|
99 |
-
num_channels_latents = self.transformer.config.in_channels // 4
|
100 |
-
latents, latent_image_ids = self.prepare_latents(
|
101 |
-
batch_size * num_images_per_prompt,
|
102 |
-
num_channels_latents,
|
103 |
-
height,
|
104 |
-
width,
|
105 |
-
prompt_embeds.dtype,
|
106 |
-
device,
|
107 |
-
generator,
|
108 |
-
latents,
|
109 |
-
)
|
110 |
-
# 5. Prepare timesteps
|
111 |
-
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
112 |
-
image_seq_len = latents.shape[1]
|
113 |
-
mu = calculate_shift(
|
114 |
-
image_seq_len,
|
115 |
-
self.scheduler.config.base_image_seq_len,
|
116 |
-
self.scheduler.config.max_image_seq_len,
|
117 |
-
self.scheduler.config.base_shift,
|
118 |
-
self.scheduler.config.max_shift,
|
119 |
-
)
|
120 |
-
timesteps, num_inference_steps = retrieve_timesteps(
|
121 |
-
self.scheduler,
|
122 |
-
num_inference_steps,
|
123 |
-
device,
|
124 |
-
timesteps,
|
125 |
-
sigmas,
|
126 |
-
mu=mu,
|
127 |
-
)
|
128 |
-
self._num_timesteps = len(timesteps)
|
129 |
-
|
130 |
-
# Handle guidance
|
131 |
-
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
|
132 |
-
|
133 |
-
# 6. Denoising loop
|
134 |
-
for i, t in enumerate(timesteps):
|
135 |
-
print(f"Inference step {i+1}/{num_inference_steps}")
|
136 |
-
if self.interrupt:
|
137 |
-
continue
|
138 |
-
|
139 |
-
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
140 |
-
|
141 |
-
noise_pred = self.transformer(
|
142 |
-
hidden_states=latents,
|
143 |
-
timestep=timestep / 1000,
|
144 |
-
guidance=guidance,
|
145 |
-
pooled_projections=pooled_prompt_embeds,
|
146 |
-
encoder_hidden_states=prompt_embeds,
|
147 |
-
txt_ids=text_ids,
|
148 |
-
img_ids=latent_image_ids,
|
149 |
-
joint_attention_kwargs=self.joint_attention_kwargs,
|
150 |
-
return_dict=False,
|
151 |
-
)[0]
|
152 |
-
# Yield intermediate result
|
153 |
-
latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
154 |
-
latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
155 |
-
image = self.vae.decode(latents_for_image, return_dict=False)[0]
|
156 |
-
yield self.image_processor.postprocess(image, output_type=output_type)[0]
|
157 |
-
|
158 |
-
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
159 |
-
torch.cuda.empty_cache()
|
160 |
-
|
161 |
-
# Final image using good_vae
|
162 |
-
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
163 |
-
latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor
|
164 |
-
image = good_vae.decode(latents, return_dict=False)[0]
|
165 |
-
self.maybe_free_model_hooks()
|
166 |
-
torch.cuda.empty_cache()
|
167 |
-
yield self.image_processor.postprocess(image, output_type=output_type)[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/llm/__pycache__/llm.cpython-310.pyc
ADDED
Binary file (5.41 kB). View file
|
|
models/llm/llm.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
3 |
+
import torch
|
4 |
+
# device = "cuda" # the device to load the model onto
|
5 |
+
model_name_or_dir="Qwen/Qwen2-7B-Instruct"
|
6 |
+
|
7 |
+
|
8 |
+
DEFAULT_SYSTEM_PROMPT = """*Given the user's input describing an object, concept, or vague idea, generate a concise and vivid prompt for the diffusion model that portrays a 3D object based solely on that input. Do not include scenes or backgrounds. The prompt should include specific descriptions for each of the four views—front, left side, rear, and right side—that will be displayed in a 2x4 grid (RGB images on the top row and normal maps on the bottom row). Put all descriptions in one single line. Focus on enhancing the cuteness and 3D qualities of the object without including any background or scene elements. Use descriptive adjectives and, if appropriate, stylistic elements to amplify the object's appeal.*
|
9 |
+
|
10 |
+
---
|
11 |
+
|
12 |
+
**Examples: (Please follow the OUTPUT Format of the following examples.)**
|
13 |
+
|
14 |
+
- **User's Input:** "我喜欢蘑菇."
|
15 |
+
A charming 3D mushroom character with a cheerful expression and blushing cheeks, styled in a whimsical, cartoonish manner. Front view displays a wide, happy smile, round eyes, and a polka-dotted cap with a small ladybug perched on top; left side view reveals a miniature satchel with a tiny acorn charm hanging from its stem; rear view shows a cute, tiny backpack decorated with mushroom patterns and a small patch of grass at the base; right side view features a petite, colorful umbrella tucked under its cap, with a ladybug sitting on the handle. No background. Arrange in a 2x4 grid with RGB images on top and normal maps below.
|
16 |
+
|
17 |
+
- **User's Input:** "画点关于太空的东西吧."
|
18 |
+
A delightful 3D astronaut plush toy with oversized, twinkling eyes and a tiny, shiny helmet, styled in an endearing, kawaii fashion. Front view showcases a joyful smile, a sparkly visor, and a round emblem with a star on the chest; left side view highlights a small flag patch on the arm, with a tiny rocket embroidery; rear view reveals a heart-shaped mini oxygen tank with a playful bow attached; right side view displays a waving hand adorned with tiny, glittering stars and a wristband with planets. No background. Display in a 2x4 grid, top row RGB images, bottom row normal maps.
|
19 |
+
|
20 |
+
- **User's Input:** "老哥,画条龙?"
|
21 |
+
A tiny, chubby 3D dragon with a joyful expression and dainty wings, styled in a cute, fantasy-inspired manner. Front view presents large, sparkling eyes, small curved horns, and a toothy grin; left side view features a little pouch hanging from its neck with a golden coin peeking out; rear view reveals a heart-shaped tail adorned with small, shimmering scales; right side view displays a miniature shield with a dragon emblem, and a wing folded in a playful manner. No background. Presented in a 2x4 grid with RGB images above and normal maps below.
|
22 |
+
|
23 |
+
- **User's Input:** "Maybe a robot?"
|
24 |
+
A lovable 3D robot with a round, friendly body and an inviting smile, styled in a sleek, minimalist design. Front view shows glowing, expressive eyes, a cheerful mouth, and a touch-screen panel with a smiley face; left side view highlights a side antenna with a blinking light and a small digital clock display; rear view reveals a charming power pack with colorful circuits and a sticker of a smiling sun; right side view features a mechanical arm holding a tiny flower with a ladybug perched on a petal. No scene elements. Organize in a 2x4 grid, RGB images on the top row, normal maps on the bottom row.
|
25 |
+
|
26 |
+
---
|
27 |
+
|
28 |
+
**Tips:**
|
29 |
+
|
30 |
+
- **Use Stylized Descriptions:** Mention styles that enhance cuteness (e.g., chibi, kawaii, cartoonish).
|
31 |
+
|
32 |
+
- **Incorporate Expressive Features:** Emphasize features like big eyes, smiles, or playful accessories.
|
33 |
+
|
34 |
+
- **Tailor View-Specific Details:** Ensure each view adds unique details to enrich the object's visual appeal.
|
35 |
+
|
36 |
+
- **Avoid Ambiguity:** Make sure the prompt is specific enough for the model to interpret accurately but doesn't include unnecessary information.
|
37 |
+
|
38 |
+
OUTPUT THE PROMPT ONLY!
|
39 |
+
OUTPUT ENGLISH ONLY! NOT ANY OTHER LANGUAGE, E.G., CHINESE!"""
|
40 |
+
|
41 |
+
def load_llm_model(model_name_or_dir, torch_dtype='auto', device_map='cpu'):
|
42 |
+
model = AutoModelForCausalLM.from_pretrained(
|
43 |
+
model_name_or_dir,
|
44 |
+
torch_dtype=torch_dtype,
|
45 |
+
# torch_dtype=torch.float8_e5m2,
|
46 |
+
# torch_dtype=torch.float16,
|
47 |
+
device_map=device_map
|
48 |
+
)
|
49 |
+
print(f'set llm model to {model_name_or_dir}')
|
50 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name_or_dir)
|
51 |
+
print(f'set llm tokenizer to {model_name_or_dir}')
|
52 |
+
return model, tokenizer
|
53 |
+
|
54 |
+
|
55 |
+
# print(f"Before load llm model: {torch.cuda.memory_allocated() / 1024**3} GB")
|
56 |
+
# load_model()
|
57 |
+
# print(f"After load llm model: {torch.cuda.memory_allocated() / 1024**3} GB")
|
58 |
+
|
59 |
+
def get_llm_response(model, tokenizer, user_prompt, seed=None, system_prompt=DEFAULT_SYSTEM_PROMPT):
|
60 |
+
# global model
|
61 |
+
# global tokenizer
|
62 |
+
# load_model()
|
63 |
+
|
64 |
+
messages = [
|
65 |
+
{"role": "system", "content": system_prompt},
|
66 |
+
{"role": "user", "content": user_prompt}
|
67 |
+
]
|
68 |
+
text = tokenizer.apply_chat_template(
|
69 |
+
messages,
|
70 |
+
tokenize=False,
|
71 |
+
add_generation_prompt=True
|
72 |
+
)
|
73 |
+
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
74 |
+
|
75 |
+
if seed is not None:
|
76 |
+
torch.manual_seed(seed)
|
77 |
+
|
78 |
+
# breakpoint()
|
79 |
+
generated_ids = model.generate(
|
80 |
+
model_inputs.input_ids,
|
81 |
+
max_new_tokens=512,
|
82 |
+
temperature=0.7,
|
83 |
+
)
|
84 |
+
generated_ids = [
|
85 |
+
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
86 |
+
]
|
87 |
+
|
88 |
+
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
89 |
+
|
90 |
+
return response
|
91 |
+
|
92 |
+
# if __name__ == "__main__":
|
93 |
+
|
94 |
+
# user_prompt="哈利波特"
|
95 |
+
# rsp = get_response(user_prompt, seed=0)
|
96 |
+
# print(rsp)
|
97 |
+
# breakpoint()
|
pipeline/custom_pipelines/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
|
2 |
+
from .pipeline_flux_img2img import FluxImg2ImgPipeline
|
3 |
+
from .pipeline_flux_prior_redux import FluxPriorReduxPipeline
|
pipeline/custom_pipelines/pipeline_flux_controlnet_image_to_image.py
ADDED
@@ -0,0 +1,1004 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copied from diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
|
2 |
+
|
3 |
+
import inspect
|
4 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from transformers import (
|
9 |
+
CLIPTextModel,
|
10 |
+
CLIPTokenizer,
|
11 |
+
T5EncoderModel,
|
12 |
+
T5TokenizerFast,
|
13 |
+
)
|
14 |
+
|
15 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
16 |
+
from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
17 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
18 |
+
from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
|
19 |
+
from diffusers.models.transformers import FluxTransformer2DModel
|
20 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
21 |
+
from diffusers.utils import (
|
22 |
+
USE_PEFT_BACKEND,
|
23 |
+
is_torch_xla_available,
|
24 |
+
logging,
|
25 |
+
replace_example_docstring,
|
26 |
+
scale_lora_layers,
|
27 |
+
unscale_lora_layers,
|
28 |
+
)
|
29 |
+
from diffusers.utils.torch_utils import randn_tensor
|
30 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
31 |
+
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
32 |
+
|
33 |
+
|
34 |
+
if is_torch_xla_available():
|
35 |
+
import torch_xla.core.xla_model as xm
|
36 |
+
|
37 |
+
XLA_AVAILABLE = True
|
38 |
+
else:
|
39 |
+
XLA_AVAILABLE = False
|
40 |
+
|
41 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
42 |
+
|
43 |
+
EXAMPLE_DOC_STRING = """
|
44 |
+
Examples:
|
45 |
+
```py
|
46 |
+
>>> import torch
|
47 |
+
>>> from diffusers import FluxControlNetImg2ImgPipeline, FluxControlNetModel
|
48 |
+
>>> from diffusers.utils import load_image
|
49 |
+
|
50 |
+
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
|
51 |
+
|
52 |
+
>>> controlnet = FluxControlNetModel.from_pretrained(
|
53 |
+
... "InstantX/FLUX.1-dev-Controlnet-Canny-alpha", torch_dtype=torch.bfloat16
|
54 |
+
... )
|
55 |
+
|
56 |
+
>>> pipe = FluxControlNetImg2ImgPipeline.from_pretrained(
|
57 |
+
... "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.float16
|
58 |
+
... )
|
59 |
+
|
60 |
+
>>> pipe.text_encoder.to(torch.float16)
|
61 |
+
>>> pipe.controlnet.to(torch.float16)
|
62 |
+
>>> pipe.to("cuda")
|
63 |
+
|
64 |
+
>>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
|
65 |
+
>>> init_image = load_image(
|
66 |
+
... "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
67 |
+
... )
|
68 |
+
|
69 |
+
>>> prompt = "A girl in city, 25 years old, cool, futuristic"
|
70 |
+
>>> image = pipe(
|
71 |
+
... prompt,
|
72 |
+
... image=init_image,
|
73 |
+
... control_image=control_image,
|
74 |
+
... control_guidance_start=0.2,
|
75 |
+
... control_guidance_end=0.8,
|
76 |
+
... controlnet_conditioning_scale=1.0,
|
77 |
+
... strength=0.7,
|
78 |
+
... num_inference_steps=2,
|
79 |
+
... guidance_scale=3.5,
|
80 |
+
... ).images[0]
|
81 |
+
>>> image.save("flux_controlnet_img2img.png")
|
82 |
+
```
|
83 |
+
"""
|
84 |
+
|
85 |
+
|
86 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
87 |
+
def calculate_shift(
|
88 |
+
image_seq_len,
|
89 |
+
base_seq_len: int = 256,
|
90 |
+
max_seq_len: int = 4096,
|
91 |
+
base_shift: float = 0.5,
|
92 |
+
max_shift: float = 1.16,
|
93 |
+
):
|
94 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
95 |
+
b = base_shift - m * base_seq_len
|
96 |
+
mu = image_seq_len * m + b
|
97 |
+
return mu
|
98 |
+
|
99 |
+
|
100 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
101 |
+
def retrieve_latents(
|
102 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
103 |
+
):
|
104 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
105 |
+
return encoder_output.latent_dist.sample(generator)
|
106 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
107 |
+
return encoder_output.latent_dist.mode()
|
108 |
+
elif hasattr(encoder_output, "latents"):
|
109 |
+
return encoder_output.latents
|
110 |
+
else:
|
111 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
112 |
+
|
113 |
+
|
114 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
115 |
+
def retrieve_timesteps(
|
116 |
+
scheduler,
|
117 |
+
num_inference_steps: Optional[int] = None,
|
118 |
+
device: Optional[Union[str, torch.device]] = None,
|
119 |
+
timesteps: Optional[List[int]] = None,
|
120 |
+
sigmas: Optional[List[float]] = None,
|
121 |
+
**kwargs,
|
122 |
+
):
|
123 |
+
r"""
|
124 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
125 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
126 |
+
|
127 |
+
Args:
|
128 |
+
scheduler (`SchedulerMixin`):
|
129 |
+
The scheduler to get timesteps from.
|
130 |
+
num_inference_steps (`int`):
|
131 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
132 |
+
must be `None`.
|
133 |
+
device (`str` or `torch.device`, *optional*):
|
134 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
135 |
+
timesteps (`List[int]`, *optional*):
|
136 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
137 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
138 |
+
sigmas (`List[float]`, *optional*):
|
139 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
140 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
144 |
+
second element is the number of inference steps.
|
145 |
+
"""
|
146 |
+
if timesteps is not None and sigmas is not None:
|
147 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
148 |
+
if timesteps is not None:
|
149 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
150 |
+
if not accepts_timesteps:
|
151 |
+
raise ValueError(
|
152 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
153 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
154 |
+
)
|
155 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
156 |
+
timesteps = scheduler.timesteps
|
157 |
+
num_inference_steps = len(timesteps)
|
158 |
+
else:
|
159 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
160 |
+
if accept_sigmas:
|
161 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
162 |
+
# if not accept_sigmas:
|
163 |
+
# raise ValueError(
|
164 |
+
# f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
165 |
+
# f" sigmas schedules. Please check whether you are using the correct scheduler."
|
166 |
+
# )
|
167 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
168 |
+
timesteps = scheduler.timesteps
|
169 |
+
num_inference_steps = len(timesteps)
|
170 |
+
else:
|
171 |
+
scheduler.set_timesteps(num_inference_steps, device=device)#, **kwargs)
|
172 |
+
timesteps = scheduler.timesteps
|
173 |
+
|
174 |
+
return timesteps, num_inference_steps
|
175 |
+
|
176 |
+
|
177 |
+
class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
|
178 |
+
r"""
|
179 |
+
The Flux controlnet pipeline for image-to-image generation.
|
180 |
+
|
181 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
182 |
+
|
183 |
+
Args:
|
184 |
+
transformer ([`FluxTransformer2DModel`]):
|
185 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
186 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
187 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
188 |
+
vae ([`AutoencoderKL`]):
|
189 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
190 |
+
text_encoder ([`CLIPTextModel`]):
|
191 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
192 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
193 |
+
text_encoder_2 ([`T5EncoderModel`]):
|
194 |
+
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
195 |
+
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
196 |
+
tokenizer (`CLIPTokenizer`):
|
197 |
+
Tokenizer of class
|
198 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
199 |
+
tokenizer_2 (`T5TokenizerFast`):
|
200 |
+
Second Tokenizer of class
|
201 |
+
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
202 |
+
"""
|
203 |
+
|
204 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
205 |
+
_optional_components = []
|
206 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
207 |
+
|
208 |
+
def __init__(
|
209 |
+
self,
|
210 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
211 |
+
vae: AutoencoderKL,
|
212 |
+
text_encoder: CLIPTextModel,
|
213 |
+
tokenizer: CLIPTokenizer,
|
214 |
+
text_encoder_2: T5EncoderModel,
|
215 |
+
tokenizer_2: T5TokenizerFast,
|
216 |
+
transformer: FluxTransformer2DModel,
|
217 |
+
controlnet: Union[
|
218 |
+
FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel
|
219 |
+
],
|
220 |
+
):
|
221 |
+
super().__init__()
|
222 |
+
if isinstance(controlnet, (list, tuple)):
|
223 |
+
controlnet = FluxMultiControlNetModel(controlnet)
|
224 |
+
|
225 |
+
self.register_modules(
|
226 |
+
vae=vae,
|
227 |
+
text_encoder=text_encoder,
|
228 |
+
text_encoder_2=text_encoder_2,
|
229 |
+
tokenizer=tokenizer,
|
230 |
+
tokenizer_2=tokenizer_2,
|
231 |
+
transformer=transformer,
|
232 |
+
scheduler=scheduler,
|
233 |
+
controlnet=controlnet,
|
234 |
+
)
|
235 |
+
self.vae_scale_factor = (
|
236 |
+
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
237 |
+
)
|
238 |
+
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
239 |
+
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
240 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
241 |
+
self.tokenizer_max_length = (
|
242 |
+
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
243 |
+
)
|
244 |
+
self.default_sample_size = 128
|
245 |
+
|
246 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
|
247 |
+
def _get_t5_prompt_embeds(
|
248 |
+
self,
|
249 |
+
prompt: Union[str, List[str]] = None,
|
250 |
+
num_images_per_prompt: int = 1,
|
251 |
+
max_sequence_length: int = 512,
|
252 |
+
device: Optional[torch.device] = None,
|
253 |
+
dtype: Optional[torch.dtype] = None,
|
254 |
+
):
|
255 |
+
device = device or self._execution_device
|
256 |
+
dtype = dtype or self.text_encoder.dtype
|
257 |
+
|
258 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
259 |
+
batch_size = len(prompt)
|
260 |
+
|
261 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
262 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
|
263 |
+
|
264 |
+
text_inputs = self.tokenizer_2(
|
265 |
+
prompt,
|
266 |
+
padding="max_length",
|
267 |
+
max_length=max_sequence_length,
|
268 |
+
truncation=True,
|
269 |
+
return_length=False,
|
270 |
+
return_overflowing_tokens=False,
|
271 |
+
return_tensors="pt",
|
272 |
+
)
|
273 |
+
text_input_ids = text_inputs.input_ids
|
274 |
+
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
|
275 |
+
|
276 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
277 |
+
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
278 |
+
logger.warning(
|
279 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
280 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
281 |
+
)
|
282 |
+
|
283 |
+
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
|
284 |
+
|
285 |
+
dtype = self.text_encoder_2.dtype
|
286 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
287 |
+
|
288 |
+
_, seq_len, _ = prompt_embeds.shape
|
289 |
+
|
290 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
291 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
292 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
293 |
+
|
294 |
+
return prompt_embeds
|
295 |
+
|
296 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
|
297 |
+
def _get_clip_prompt_embeds(
|
298 |
+
self,
|
299 |
+
prompt: Union[str, List[str]],
|
300 |
+
num_images_per_prompt: int = 1,
|
301 |
+
device: Optional[torch.device] = None,
|
302 |
+
):
|
303 |
+
device = device or self._execution_device
|
304 |
+
|
305 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
306 |
+
batch_size = len(prompt)
|
307 |
+
|
308 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
309 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
310 |
+
|
311 |
+
text_inputs = self.tokenizer(
|
312 |
+
prompt,
|
313 |
+
padding="max_length",
|
314 |
+
max_length=self.tokenizer_max_length,
|
315 |
+
truncation=True,
|
316 |
+
return_overflowing_tokens=False,
|
317 |
+
return_length=False,
|
318 |
+
return_tensors="pt",
|
319 |
+
)
|
320 |
+
|
321 |
+
text_input_ids = text_inputs.input_ids
|
322 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
323 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
324 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
325 |
+
logger.warning(
|
326 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
327 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
328 |
+
)
|
329 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
330 |
+
|
331 |
+
# Use pooled output of CLIPTextModel
|
332 |
+
prompt_embeds = prompt_embeds.pooler_output
|
333 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
334 |
+
|
335 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
336 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
|
337 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
338 |
+
|
339 |
+
return prompt_embeds
|
340 |
+
|
341 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
|
342 |
+
def encode_prompt(
|
343 |
+
self,
|
344 |
+
prompt: Union[str, List[str]],
|
345 |
+
prompt_2: Union[str, List[str]],
|
346 |
+
device: Optional[torch.device] = None,
|
347 |
+
num_images_per_prompt: int = 1,
|
348 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
349 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
350 |
+
max_sequence_length: int = 512,
|
351 |
+
lora_scale: Optional[float] = None,
|
352 |
+
):
|
353 |
+
r"""
|
354 |
+
|
355 |
+
Args:
|
356 |
+
prompt (`str` or `List[str]`, *optional*):
|
357 |
+
prompt to be encoded
|
358 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
359 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
360 |
+
used in all text-encoders
|
361 |
+
device: (`torch.device`):
|
362 |
+
torch device
|
363 |
+
num_images_per_prompt (`int`):
|
364 |
+
number of images that should be generated per prompt
|
365 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
366 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
367 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
368 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
369 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
370 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
371 |
+
lora_scale (`float`, *optional*):
|
372 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
373 |
+
"""
|
374 |
+
device = device or self._execution_device
|
375 |
+
|
376 |
+
# set lora scale so that monkey patched LoRA
|
377 |
+
# function of text encoder can correctly access it
|
378 |
+
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
379 |
+
self._lora_scale = lora_scale
|
380 |
+
|
381 |
+
# dynamically adjust the LoRA scale
|
382 |
+
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
383 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
384 |
+
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
385 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
386 |
+
|
387 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
388 |
+
|
389 |
+
if prompt_embeds is None:
|
390 |
+
prompt_2 = prompt_2 or prompt
|
391 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
392 |
+
|
393 |
+
# We only use the pooled prompt output from the CLIPTextModel
|
394 |
+
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
395 |
+
prompt=prompt,
|
396 |
+
device=device,
|
397 |
+
num_images_per_prompt=num_images_per_prompt,
|
398 |
+
)
|
399 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
400 |
+
prompt=prompt_2,
|
401 |
+
num_images_per_prompt=num_images_per_prompt,
|
402 |
+
max_sequence_length=max_sequence_length,
|
403 |
+
device=device,
|
404 |
+
)
|
405 |
+
|
406 |
+
if self.text_encoder is not None:
|
407 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
408 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
409 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
410 |
+
|
411 |
+
if self.text_encoder_2 is not None:
|
412 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
413 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
414 |
+
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
415 |
+
|
416 |
+
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
417 |
+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
418 |
+
|
419 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
|
420 |
+
|
421 |
+
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
|
422 |
+
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
423 |
+
if isinstance(generator, list):
|
424 |
+
image_latents = [
|
425 |
+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
426 |
+
for i in range(image.shape[0])
|
427 |
+
]
|
428 |
+
image_latents = torch.cat(image_latents, dim=0)
|
429 |
+
else:
|
430 |
+
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
431 |
+
|
432 |
+
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
433 |
+
|
434 |
+
return image_latents
|
435 |
+
|
436 |
+
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
|
437 |
+
def get_timesteps(self, num_inference_steps, strength, device):
|
438 |
+
# get the original timestep using init_timestep
|
439 |
+
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
440 |
+
|
441 |
+
t_start = int(max(num_inference_steps - init_timestep, 0))
|
442 |
+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
443 |
+
if hasattr(self.scheduler, "set_begin_index"):
|
444 |
+
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
445 |
+
|
446 |
+
return timesteps, num_inference_steps - t_start
|
447 |
+
|
448 |
+
def check_inputs(
|
449 |
+
self,
|
450 |
+
prompt,
|
451 |
+
prompt_2,
|
452 |
+
strength,
|
453 |
+
height,
|
454 |
+
width,
|
455 |
+
callback_on_step_end_tensor_inputs,
|
456 |
+
prompt_embeds=None,
|
457 |
+
pooled_prompt_embeds=None,
|
458 |
+
max_sequence_length=None,
|
459 |
+
):
|
460 |
+
if strength < 0 or strength > 1:
|
461 |
+
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
462 |
+
|
463 |
+
if height % self.vae_scale_factor * 2 != 0 or width % self.vae_scale_factor * 2 != 0:
|
464 |
+
logger.warning(
|
465 |
+
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
466 |
+
)
|
467 |
+
|
468 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
469 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
470 |
+
):
|
471 |
+
raise ValueError(
|
472 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
473 |
+
)
|
474 |
+
|
475 |
+
if prompt is not None and prompt_embeds is not None:
|
476 |
+
raise ValueError(
|
477 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
478 |
+
" only forward one of the two."
|
479 |
+
)
|
480 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
481 |
+
raise ValueError(
|
482 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
483 |
+
" only forward one of the two."
|
484 |
+
)
|
485 |
+
elif prompt is None and prompt_embeds is None:
|
486 |
+
raise ValueError(
|
487 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
488 |
+
)
|
489 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
490 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
491 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
492 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
493 |
+
|
494 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
495 |
+
raise ValueError(
|
496 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
497 |
+
)
|
498 |
+
|
499 |
+
if max_sequence_length is not None and max_sequence_length > 512:
|
500 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
501 |
+
|
502 |
+
@staticmethod
|
503 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
504 |
+
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
505 |
+
latent_image_ids = torch.zeros(height, width, 3)
|
506 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
507 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
508 |
+
|
509 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
510 |
+
|
511 |
+
latent_image_ids = latent_image_ids.reshape(
|
512 |
+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
513 |
+
)
|
514 |
+
|
515 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
516 |
+
|
517 |
+
@staticmethod
|
518 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
|
519 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
520 |
+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
521 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
522 |
+
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
523 |
+
|
524 |
+
return latents
|
525 |
+
|
526 |
+
@staticmethod
|
527 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
|
528 |
+
def _unpack_latents(latents, height, width, vae_scale_factor):
|
529 |
+
batch_size, num_patches, channels = latents.shape
|
530 |
+
|
531 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
532 |
+
# latent height and width to be divisible by 2.
|
533 |
+
height = 2 * (int(height) // (vae_scale_factor * 2))
|
534 |
+
width = 2 * (int(width) // (vae_scale_factor * 2))
|
535 |
+
|
536 |
+
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
537 |
+
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
538 |
+
|
539 |
+
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
|
540 |
+
|
541 |
+
return latents
|
542 |
+
|
543 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
|
544 |
+
def prepare_latents(
|
545 |
+
self,
|
546 |
+
image,
|
547 |
+
timestep,
|
548 |
+
batch_size,
|
549 |
+
num_channels_latents,
|
550 |
+
height,
|
551 |
+
width,
|
552 |
+
dtype,
|
553 |
+
device,
|
554 |
+
generator,
|
555 |
+
latents=None,
|
556 |
+
):
|
557 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
558 |
+
raise ValueError(
|
559 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
560 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
561 |
+
)
|
562 |
+
|
563 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
564 |
+
# latent height and width to be divisible by 2.
|
565 |
+
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
566 |
+
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
567 |
+
shape = (batch_size, num_channels_latents, height, width)
|
568 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
569 |
+
|
570 |
+
if latents is not None:
|
571 |
+
return latents.to(device=device, dtype=dtype), latent_image_ids
|
572 |
+
|
573 |
+
image = image.to(device=device, dtype=dtype)
|
574 |
+
image_latents = self._encode_vae_image(image=image, generator=generator)
|
575 |
+
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
576 |
+
# expand init_latents for batch_size
|
577 |
+
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
578 |
+
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
579 |
+
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
580 |
+
raise ValueError(
|
581 |
+
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
582 |
+
)
|
583 |
+
else:
|
584 |
+
image_latents = torch.cat([image_latents], dim=0)
|
585 |
+
|
586 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
587 |
+
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
|
588 |
+
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
589 |
+
return latents, latent_image_ids
|
590 |
+
|
591 |
+
# Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
|
592 |
+
def prepare_image(
|
593 |
+
self,
|
594 |
+
image,
|
595 |
+
width,
|
596 |
+
height,
|
597 |
+
batch_size,
|
598 |
+
num_images_per_prompt,
|
599 |
+
device,
|
600 |
+
dtype,
|
601 |
+
do_classifier_free_guidance=False,
|
602 |
+
guess_mode=False,
|
603 |
+
):
|
604 |
+
if isinstance(image, torch.Tensor):
|
605 |
+
pass
|
606 |
+
else:
|
607 |
+
image = self.image_processor.preprocess(image, height=height, width=width)
|
608 |
+
|
609 |
+
image_batch_size = image.shape[0]
|
610 |
+
|
611 |
+
if image_batch_size == 1:
|
612 |
+
repeat_by = batch_size
|
613 |
+
else:
|
614 |
+
# image batch size is the same as prompt batch size
|
615 |
+
repeat_by = num_images_per_prompt
|
616 |
+
|
617 |
+
image = image.repeat_interleave(repeat_by, dim=0)
|
618 |
+
|
619 |
+
image = image.to(device=device, dtype=dtype)
|
620 |
+
|
621 |
+
if do_classifier_free_guidance and not guess_mode:
|
622 |
+
image = torch.cat([image] * 2)
|
623 |
+
|
624 |
+
return image
|
625 |
+
|
626 |
+
@property
|
627 |
+
def guidance_scale(self):
|
628 |
+
return self._guidance_scale
|
629 |
+
|
630 |
+
@property
|
631 |
+
def joint_attention_kwargs(self):
|
632 |
+
return self._joint_attention_kwargs
|
633 |
+
|
634 |
+
@property
|
635 |
+
def num_timesteps(self):
|
636 |
+
return self._num_timesteps
|
637 |
+
|
638 |
+
@property
|
639 |
+
def interrupt(self):
|
640 |
+
return self._interrupt
|
641 |
+
|
642 |
+
@torch.no_grad()
|
643 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
644 |
+
def __call__(
|
645 |
+
self,
|
646 |
+
prompt: Union[str, List[str]] = None,
|
647 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
648 |
+
image: PipelineImageInput = None,
|
649 |
+
control_image: PipelineImageInput = None,
|
650 |
+
height: Optional[int] = None,
|
651 |
+
width: Optional[int] = None,
|
652 |
+
strength: float = 0.6,
|
653 |
+
num_inference_steps: int = 28,
|
654 |
+
sigmas: Optional[List[float]] = None,
|
655 |
+
guidance_scale: float = 7.0,
|
656 |
+
control_guidance_start: Union[float, List[float]] = 0.0,
|
657 |
+
control_guidance_end: Union[float, List[float]] = 1.0,
|
658 |
+
control_mode: Optional[Union[int, List[int]]] = None,
|
659 |
+
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
660 |
+
num_images_per_prompt: Optional[int] = 1,
|
661 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
662 |
+
latents: Optional[torch.FloatTensor] = None,
|
663 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
664 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
665 |
+
output_type: Optional[str] = "pil",
|
666 |
+
return_dict: bool = True,
|
667 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
668 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
669 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
670 |
+
max_sequence_length: int = 512,
|
671 |
+
):
|
672 |
+
"""
|
673 |
+
Function invoked when calling the pipeline for generation.
|
674 |
+
|
675 |
+
Args:
|
676 |
+
prompt (`str` or `List[str]`, *optional*):
|
677 |
+
The prompt or prompts to guide the image generation.
|
678 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
679 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`.
|
680 |
+
image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
|
681 |
+
The image(s) to modify with the pipeline.
|
682 |
+
control_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
|
683 |
+
The ControlNet input condition. Image to control the generation.
|
684 |
+
height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
|
685 |
+
The height in pixels of the generated image.
|
686 |
+
width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
|
687 |
+
The width in pixels of the generated image.
|
688 |
+
strength (`float`, *optional*, defaults to 0.6):
|
689 |
+
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
|
690 |
+
num_inference_steps (`int`, *optional*, defaults to 28):
|
691 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
692 |
+
expense of slower inference.
|
693 |
+
sigmas (`List[float]`, *optional*):
|
694 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
695 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
696 |
+
will be used.
|
697 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
698 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
699 |
+
control_mode (`int` or `List[int]`, *optional*):
|
700 |
+
The mode for the ControlNet. If multiple ControlNets are used, this should be a list.
|
701 |
+
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
702 |
+
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
|
703 |
+
to the residual in the original transformer.
|
704 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
705 |
+
The number of images to generate per prompt.
|
706 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
707 |
+
One or more [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to
|
708 |
+
make generation deterministic.
|
709 |
+
latents (`torch.FloatTensor`, *optional*):
|
710 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
711 |
+
generation. Can be used to tweak the same generation with different prompts.
|
712 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
713 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
714 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
715 |
+
Pre-generated pooled text embeddings.
|
716 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
717 |
+
The output format of the generate image. Choose between `PIL.Image` or `np.array`.
|
718 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
719 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
720 |
+
joint_attention_kwargs (`dict`, *optional*):
|
721 |
+
Additional keyword arguments to be passed to the joint attention mechanism.
|
722 |
+
callback_on_step_end (`Callable`, *optional*):
|
723 |
+
A function that calls at the end of each denoising step during the inference.
|
724 |
+
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
|
725 |
+
The list of tensor inputs for the `callback_on_step_end` function.
|
726 |
+
max_sequence_length (`int`, *optional*, defaults to 512):
|
727 |
+
The maximum length of the sequence to be generated.
|
728 |
+
|
729 |
+
Examples:
|
730 |
+
|
731 |
+
Returns:
|
732 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
733 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
734 |
+
images.
|
735 |
+
"""
|
736 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
737 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
738 |
+
|
739 |
+
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
740 |
+
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
741 |
+
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
742 |
+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
743 |
+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
744 |
+
mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
|
745 |
+
control_guidance_start, control_guidance_end = (
|
746 |
+
mult * [control_guidance_start],
|
747 |
+
mult * [control_guidance_end],
|
748 |
+
)
|
749 |
+
|
750 |
+
self.check_inputs(
|
751 |
+
prompt,
|
752 |
+
prompt_2,
|
753 |
+
strength,
|
754 |
+
height,
|
755 |
+
width,
|
756 |
+
callback_on_step_end_tensor_inputs,
|
757 |
+
prompt_embeds=prompt_embeds,
|
758 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
759 |
+
max_sequence_length=max_sequence_length,
|
760 |
+
)
|
761 |
+
|
762 |
+
self._guidance_scale = guidance_scale
|
763 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
764 |
+
self._interrupt = False
|
765 |
+
|
766 |
+
if prompt is not None and isinstance(prompt, str):
|
767 |
+
batch_size = 1
|
768 |
+
elif prompt is not None and isinstance(prompt, list):
|
769 |
+
batch_size = len(prompt)
|
770 |
+
else:
|
771 |
+
batch_size = prompt_embeds.shape[0]
|
772 |
+
|
773 |
+
device = self._execution_device
|
774 |
+
dtype = self.transformer.dtype
|
775 |
+
|
776 |
+
lora_scale = (
|
777 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
778 |
+
)
|
779 |
+
(
|
780 |
+
prompt_embeds,
|
781 |
+
pooled_prompt_embeds,
|
782 |
+
text_ids,
|
783 |
+
) = self.encode_prompt(
|
784 |
+
prompt=prompt,
|
785 |
+
prompt_2=prompt_2,
|
786 |
+
prompt_embeds=prompt_embeds,
|
787 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
788 |
+
device=device,
|
789 |
+
num_images_per_prompt=num_images_per_prompt,
|
790 |
+
max_sequence_length=max_sequence_length,
|
791 |
+
lora_scale=lora_scale,
|
792 |
+
)
|
793 |
+
|
794 |
+
init_image = self.image_processor.preprocess(image, height=height, width=width)
|
795 |
+
init_image = init_image.to(dtype=torch.float32)
|
796 |
+
|
797 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
798 |
+
|
799 |
+
if isinstance(self.controlnet, FluxControlNetModel):
|
800 |
+
control_image = self.prepare_image(
|
801 |
+
image=control_image,
|
802 |
+
width=width,
|
803 |
+
height=height,
|
804 |
+
batch_size=batch_size * num_images_per_prompt,
|
805 |
+
num_images_per_prompt=num_images_per_prompt,
|
806 |
+
device=device,
|
807 |
+
dtype=self.vae.dtype,
|
808 |
+
)
|
809 |
+
height, width = control_image.shape[-2:]
|
810 |
+
|
811 |
+
control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
|
812 |
+
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
813 |
+
|
814 |
+
height_control_image, width_control_image = control_image.shape[2:]
|
815 |
+
control_image = self._pack_latents(
|
816 |
+
control_image,
|
817 |
+
batch_size * num_images_per_prompt,
|
818 |
+
num_channels_latents,
|
819 |
+
height_control_image,
|
820 |
+
width_control_image,
|
821 |
+
)
|
822 |
+
|
823 |
+
if control_mode is not None:
|
824 |
+
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
|
825 |
+
control_mode = control_mode.reshape([-1, 1])
|
826 |
+
|
827 |
+
elif isinstance(self.controlnet, FluxMultiControlNetModel):
|
828 |
+
control_images = []
|
829 |
+
|
830 |
+
for control_image_ in control_image:
|
831 |
+
control_image_ = self.prepare_image(
|
832 |
+
image=control_image_,
|
833 |
+
width=width,
|
834 |
+
height=height,
|
835 |
+
batch_size=batch_size * num_images_per_prompt,
|
836 |
+
num_images_per_prompt=num_images_per_prompt,
|
837 |
+
device=device,
|
838 |
+
dtype=self.vae.dtype,
|
839 |
+
)
|
840 |
+
height, width = control_image_.shape[-2:]
|
841 |
+
|
842 |
+
control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
|
843 |
+
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
844 |
+
|
845 |
+
height_control_image, width_control_image = control_image_.shape[2:]
|
846 |
+
control_image_ = self._pack_latents(
|
847 |
+
control_image_,
|
848 |
+
batch_size * num_images_per_prompt,
|
849 |
+
num_channels_latents,
|
850 |
+
height_control_image,
|
851 |
+
width_control_image,
|
852 |
+
)
|
853 |
+
|
854 |
+
control_images.append(control_image_)
|
855 |
+
|
856 |
+
control_image = control_images
|
857 |
+
|
858 |
+
control_mode_ = []
|
859 |
+
if isinstance(control_mode, list):
|
860 |
+
for cmode in control_mode:
|
861 |
+
if cmode is None:
|
862 |
+
control_mode_.append(-1)
|
863 |
+
else:
|
864 |
+
control_mode_.append(cmode)
|
865 |
+
control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
|
866 |
+
control_mode = control_mode.reshape([-1, 1])
|
867 |
+
|
868 |
+
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
|
869 |
+
mu = calculate_shift(
|
870 |
+
image_seq_len,
|
871 |
+
self.scheduler.config.base_image_seq_len,
|
872 |
+
self.scheduler.config.max_image_seq_len,
|
873 |
+
self.scheduler.config.base_shift,
|
874 |
+
self.scheduler.config.max_shift,
|
875 |
+
)
|
876 |
+
|
877 |
+
# sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
878 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
879 |
+
self.scheduler,
|
880 |
+
num_inference_steps,
|
881 |
+
device,
|
882 |
+
sigmas=sigmas,
|
883 |
+
mu=mu,
|
884 |
+
)
|
885 |
+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
886 |
+
|
887 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
888 |
+
latents, latent_image_ids = self.prepare_latents(
|
889 |
+
init_image,
|
890 |
+
latent_timestep,
|
891 |
+
batch_size * num_images_per_prompt,
|
892 |
+
num_channels_latents,
|
893 |
+
height,
|
894 |
+
width,
|
895 |
+
prompt_embeds.dtype,
|
896 |
+
device,
|
897 |
+
generator,
|
898 |
+
latents,
|
899 |
+
)
|
900 |
+
|
901 |
+
controlnet_keep = []
|
902 |
+
for i in range(len(timesteps)):
|
903 |
+
keeps = [
|
904 |
+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
905 |
+
for s, e in zip(control_guidance_start, control_guidance_end)
|
906 |
+
]
|
907 |
+
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
|
908 |
+
|
909 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
910 |
+
self._num_timesteps = len(timesteps)
|
911 |
+
|
912 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
913 |
+
for i, t in enumerate(timesteps):
|
914 |
+
if self.interrupt:
|
915 |
+
continue
|
916 |
+
|
917 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
918 |
+
|
919 |
+
if isinstance(self.controlnet, FluxMultiControlNetModel):
|
920 |
+
use_guidance = self.controlnet.nets[0].config.guidance_embeds
|
921 |
+
else:
|
922 |
+
use_guidance = self.controlnet.config.guidance_embeds
|
923 |
+
|
924 |
+
guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
|
925 |
+
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
|
926 |
+
|
927 |
+
if isinstance(controlnet_keep[i], list):
|
928 |
+
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
929 |
+
else:
|
930 |
+
controlnet_cond_scale = controlnet_conditioning_scale
|
931 |
+
if isinstance(controlnet_cond_scale, list):
|
932 |
+
controlnet_cond_scale = controlnet_cond_scale[0]
|
933 |
+
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
934 |
+
|
935 |
+
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
|
936 |
+
hidden_states=latents,
|
937 |
+
controlnet_cond=control_image,
|
938 |
+
controlnet_mode=control_mode,
|
939 |
+
conditioning_scale=cond_scale,
|
940 |
+
timestep=timestep / 1000,
|
941 |
+
guidance=guidance,
|
942 |
+
pooled_projections=pooled_prompt_embeds,
|
943 |
+
encoder_hidden_states=prompt_embeds,
|
944 |
+
txt_ids=text_ids,
|
945 |
+
img_ids=latent_image_ids,
|
946 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
947 |
+
return_dict=False,
|
948 |
+
)
|
949 |
+
|
950 |
+
guidance = (
|
951 |
+
torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None
|
952 |
+
)
|
953 |
+
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
|
954 |
+
|
955 |
+
noise_pred = self.transformer(
|
956 |
+
hidden_states=latents,
|
957 |
+
timestep=timestep / 1000,
|
958 |
+
guidance=guidance,
|
959 |
+
pooled_projections=pooled_prompt_embeds,
|
960 |
+
encoder_hidden_states=prompt_embeds,
|
961 |
+
controlnet_block_samples=controlnet_block_samples,
|
962 |
+
controlnet_single_block_samples=controlnet_single_block_samples,
|
963 |
+
txt_ids=text_ids,
|
964 |
+
img_ids=latent_image_ids,
|
965 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
966 |
+
return_dict=False,
|
967 |
+
)[0]
|
968 |
+
|
969 |
+
latents_dtype = latents.dtype
|
970 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
971 |
+
|
972 |
+
if latents.dtype != latents_dtype:
|
973 |
+
if torch.backends.mps.is_available():
|
974 |
+
latents = latents.to(latents_dtype)
|
975 |
+
|
976 |
+
if callback_on_step_end is not None:
|
977 |
+
callback_kwargs = {}
|
978 |
+
for k in callback_on_step_end_tensor_inputs:
|
979 |
+
callback_kwargs[k] = locals()[k]
|
980 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
981 |
+
|
982 |
+
latents = callback_outputs.pop("latents", latents)
|
983 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
984 |
+
|
985 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
986 |
+
progress_bar.update()
|
987 |
+
|
988 |
+
if XLA_AVAILABLE:
|
989 |
+
xm.mark_step()
|
990 |
+
|
991 |
+
if output_type == "latent":
|
992 |
+
image = latents
|
993 |
+
else:
|
994 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
995 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
996 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
997 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
998 |
+
|
999 |
+
self.maybe_free_model_hooks()
|
1000 |
+
|
1001 |
+
if not return_dict:
|
1002 |
+
return (image,)
|
1003 |
+
|
1004 |
+
return FluxPipelineOutput(images=image)
|
pipeline/custom_pipelines/pipeline_flux_img2img.py
ADDED
@@ -0,0 +1,862 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# copied from diffusers/src/diffusers/pipeline/flux/pipeline_flux_img2img.py
|
3 |
+
|
4 |
+
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
import inspect
|
19 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import torch
|
23 |
+
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
24 |
+
|
25 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
26 |
+
from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
27 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
28 |
+
from diffusers.models.transformers import FluxTransformer2DModel
|
29 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
30 |
+
from diffusers.utils import (
|
31 |
+
USE_PEFT_BACKEND,
|
32 |
+
is_torch_xla_available,
|
33 |
+
logging,
|
34 |
+
replace_example_docstring,
|
35 |
+
scale_lora_layers,
|
36 |
+
unscale_lora_layers,
|
37 |
+
)
|
38 |
+
from diffusers.utils.torch_utils import randn_tensor
|
39 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
40 |
+
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
41 |
+
|
42 |
+
|
43 |
+
if is_torch_xla_available():
|
44 |
+
import torch_xla.core.xla_model as xm
|
45 |
+
|
46 |
+
XLA_AVAILABLE = True
|
47 |
+
else:
|
48 |
+
XLA_AVAILABLE = False
|
49 |
+
|
50 |
+
|
51 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
52 |
+
|
53 |
+
EXAMPLE_DOC_STRING = """
|
54 |
+
Examples:
|
55 |
+
```py
|
56 |
+
>>> import torch
|
57 |
+
|
58 |
+
>>> from diffusers import FluxImg2ImgPipeline
|
59 |
+
>>> from diffusers.utils import load_image
|
60 |
+
|
61 |
+
>>> device = "cuda"
|
62 |
+
>>> pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
|
63 |
+
>>> pipe = pipe.to(device)
|
64 |
+
|
65 |
+
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
66 |
+
>>> init_image = load_image(url).resize((1024, 1024))
|
67 |
+
|
68 |
+
>>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
|
69 |
+
|
70 |
+
>>> images = pipe(
|
71 |
+
... prompt=prompt, image=init_image, num_inference_steps=4, strength=0.95, guidance_scale=0.0
|
72 |
+
... ).images[0]
|
73 |
+
```
|
74 |
+
"""
|
75 |
+
|
76 |
+
|
77 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
78 |
+
def calculate_shift(
|
79 |
+
image_seq_len,
|
80 |
+
base_seq_len: int = 256,
|
81 |
+
max_seq_len: int = 4096,
|
82 |
+
base_shift: float = 0.5,
|
83 |
+
max_shift: float = 1.16,
|
84 |
+
):
|
85 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
86 |
+
b = base_shift - m * base_seq_len
|
87 |
+
mu = image_seq_len * m + b
|
88 |
+
return mu
|
89 |
+
|
90 |
+
|
91 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
92 |
+
def retrieve_latents(
|
93 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
94 |
+
):
|
95 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
96 |
+
return encoder_output.latent_dist.sample(generator)
|
97 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
98 |
+
return encoder_output.latent_dist.mode()
|
99 |
+
elif hasattr(encoder_output, "latents"):
|
100 |
+
return encoder_output.latents
|
101 |
+
else:
|
102 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
103 |
+
|
104 |
+
|
105 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
106 |
+
def retrieve_timesteps(
|
107 |
+
scheduler,
|
108 |
+
num_inference_steps: Optional[int] = None,
|
109 |
+
device: Optional[Union[str, torch.device]] = None,
|
110 |
+
timesteps: Optional[List[int]] = None,
|
111 |
+
sigmas: Optional[List[float]] = None,
|
112 |
+
**kwargs,
|
113 |
+
):
|
114 |
+
r"""
|
115 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
116 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
117 |
+
|
118 |
+
Args:
|
119 |
+
scheduler (`SchedulerMixin`):
|
120 |
+
The scheduler to get timesteps from.
|
121 |
+
num_inference_steps (`int`):
|
122 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
123 |
+
must be `None`.
|
124 |
+
device (`str` or `torch.device`, *optional*):
|
125 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
126 |
+
timesteps (`List[int]`, *optional*):
|
127 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
128 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
129 |
+
sigmas (`List[float]`, *optional*):
|
130 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
131 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
132 |
+
|
133 |
+
Returns:
|
134 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
135 |
+
second element is the number of inference steps.
|
136 |
+
"""
|
137 |
+
if timesteps is not None and sigmas is not None:
|
138 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
139 |
+
if timesteps is not None:
|
140 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
141 |
+
if not accepts_timesteps:
|
142 |
+
raise ValueError(
|
143 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
144 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
145 |
+
)
|
146 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
147 |
+
timesteps = scheduler.timesteps
|
148 |
+
num_inference_steps = len(timesteps)
|
149 |
+
else:
|
150 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
151 |
+
if accept_sigmas:
|
152 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
153 |
+
# if not accept_sigmas:
|
154 |
+
# raise ValueError(
|
155 |
+
# f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
156 |
+
# f" sigmas schedules. Please check whether you are using the correct scheduler."
|
157 |
+
# )
|
158 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
159 |
+
timesteps = scheduler.timesteps
|
160 |
+
num_inference_steps = len(timesteps)
|
161 |
+
else:
|
162 |
+
scheduler.set_timesteps(num_inference_steps, device=device)#, **kwargs)
|
163 |
+
timesteps = scheduler.timesteps
|
164 |
+
|
165 |
+
return timesteps, num_inference_steps
|
166 |
+
|
167 |
+
|
168 |
+
class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
|
169 |
+
r"""
|
170 |
+
The Flux pipeline for image inpainting.
|
171 |
+
|
172 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
173 |
+
|
174 |
+
Args:
|
175 |
+
transformer ([`FluxTransformer2DModel`]):
|
176 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
177 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
178 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
179 |
+
vae ([`AutoencoderKL`]):
|
180 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
181 |
+
text_encoder ([`CLIPTextModel`]):
|
182 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
183 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
184 |
+
text_encoder_2 ([`T5EncoderModel`]):
|
185 |
+
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
186 |
+
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
187 |
+
tokenizer (`CLIPTokenizer`):
|
188 |
+
Tokenizer of class
|
189 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
190 |
+
tokenizer_2 (`T5TokenizerFast`):
|
191 |
+
Second Tokenizer of class
|
192 |
+
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
193 |
+
"""
|
194 |
+
|
195 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
196 |
+
_optional_components = []
|
197 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
198 |
+
|
199 |
+
def __init__(
|
200 |
+
self,
|
201 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
202 |
+
vae: AutoencoderKL,
|
203 |
+
text_encoder: CLIPTextModel,
|
204 |
+
tokenizer: CLIPTokenizer,
|
205 |
+
text_encoder_2: T5EncoderModel,
|
206 |
+
tokenizer_2: T5TokenizerFast,
|
207 |
+
transformer: FluxTransformer2DModel,
|
208 |
+
):
|
209 |
+
super().__init__()
|
210 |
+
|
211 |
+
self.register_modules(
|
212 |
+
vae=vae,
|
213 |
+
text_encoder=text_encoder,
|
214 |
+
text_encoder_2=text_encoder_2,
|
215 |
+
tokenizer=tokenizer,
|
216 |
+
tokenizer_2=tokenizer_2,
|
217 |
+
transformer=transformer,
|
218 |
+
scheduler=scheduler,
|
219 |
+
)
|
220 |
+
self.vae_scale_factor = (
|
221 |
+
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
222 |
+
)
|
223 |
+
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
224 |
+
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
225 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
226 |
+
self.tokenizer_max_length = (
|
227 |
+
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
228 |
+
)
|
229 |
+
self.default_sample_size = 128
|
230 |
+
|
231 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
|
232 |
+
def _get_t5_prompt_embeds(
|
233 |
+
self,
|
234 |
+
prompt: Union[str, List[str]] = None,
|
235 |
+
num_images_per_prompt: int = 1,
|
236 |
+
max_sequence_length: int = 512,
|
237 |
+
device: Optional[torch.device] = None,
|
238 |
+
dtype: Optional[torch.dtype] = None,
|
239 |
+
):
|
240 |
+
device = device or self._execution_device
|
241 |
+
dtype = dtype or self.text_encoder.dtype
|
242 |
+
|
243 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
244 |
+
batch_size = len(prompt)
|
245 |
+
|
246 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
247 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
|
248 |
+
|
249 |
+
text_inputs = self.tokenizer_2(
|
250 |
+
prompt,
|
251 |
+
padding="max_length",
|
252 |
+
max_length=max_sequence_length,
|
253 |
+
truncation=True,
|
254 |
+
return_length=False,
|
255 |
+
return_overflowing_tokens=False,
|
256 |
+
return_tensors="pt",
|
257 |
+
)
|
258 |
+
text_input_ids = text_inputs.input_ids
|
259 |
+
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
|
260 |
+
|
261 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
262 |
+
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
263 |
+
logger.warning(
|
264 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
265 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
266 |
+
)
|
267 |
+
|
268 |
+
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
|
269 |
+
|
270 |
+
dtype = self.text_encoder_2.dtype
|
271 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
272 |
+
|
273 |
+
_, seq_len, _ = prompt_embeds.shape
|
274 |
+
|
275 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
276 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
277 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
278 |
+
|
279 |
+
return prompt_embeds
|
280 |
+
|
281 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
|
282 |
+
def _get_clip_prompt_embeds(
|
283 |
+
self,
|
284 |
+
prompt: Union[str, List[str]],
|
285 |
+
num_images_per_prompt: int = 1,
|
286 |
+
device: Optional[torch.device] = None,
|
287 |
+
):
|
288 |
+
device = device or self._execution_device
|
289 |
+
|
290 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
291 |
+
batch_size = len(prompt)
|
292 |
+
|
293 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
294 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
295 |
+
|
296 |
+
text_inputs = self.tokenizer(
|
297 |
+
prompt,
|
298 |
+
padding="max_length",
|
299 |
+
max_length=self.tokenizer_max_length,
|
300 |
+
truncation=True,
|
301 |
+
return_overflowing_tokens=False,
|
302 |
+
return_length=False,
|
303 |
+
return_tensors="pt",
|
304 |
+
)
|
305 |
+
|
306 |
+
text_input_ids = text_inputs.input_ids
|
307 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
308 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
309 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
310 |
+
logger.warning(
|
311 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
312 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
313 |
+
)
|
314 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
315 |
+
|
316 |
+
# Use pooled output of CLIPTextModel
|
317 |
+
prompt_embeds = prompt_embeds.pooler_output
|
318 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
319 |
+
|
320 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
321 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
|
322 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
323 |
+
|
324 |
+
return prompt_embeds
|
325 |
+
|
326 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
|
327 |
+
def encode_prompt(
|
328 |
+
self,
|
329 |
+
prompt: Union[str, List[str]],
|
330 |
+
prompt_2: Union[str, List[str]],
|
331 |
+
device: Optional[torch.device] = None,
|
332 |
+
num_images_per_prompt: int = 1,
|
333 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
334 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
335 |
+
max_sequence_length: int = 512,
|
336 |
+
lora_scale: Optional[float] = None,
|
337 |
+
):
|
338 |
+
r"""
|
339 |
+
|
340 |
+
Args:
|
341 |
+
prompt (`str` or `List[str]`, *optional*):
|
342 |
+
prompt to be encoded
|
343 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
344 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
345 |
+
used in all text-encoders
|
346 |
+
device: (`torch.device`):
|
347 |
+
torch device
|
348 |
+
num_images_per_prompt (`int`):
|
349 |
+
number of images that should be generated per prompt
|
350 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
351 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
352 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
353 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
354 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
355 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
356 |
+
lora_scale (`float`, *optional*):
|
357 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
358 |
+
"""
|
359 |
+
device = device or self._execution_device
|
360 |
+
|
361 |
+
# set lora scale so that monkey patched LoRA
|
362 |
+
# function of text encoder can correctly access it
|
363 |
+
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
364 |
+
self._lora_scale = lora_scale
|
365 |
+
|
366 |
+
# dynamically adjust the LoRA scale
|
367 |
+
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
368 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
369 |
+
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
370 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
371 |
+
|
372 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
373 |
+
|
374 |
+
if prompt_embeds is None:
|
375 |
+
prompt_2 = prompt_2 or prompt
|
376 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
377 |
+
|
378 |
+
# We only use the pooled prompt output from the CLIPTextModel
|
379 |
+
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
380 |
+
prompt=prompt,
|
381 |
+
device=device,
|
382 |
+
num_images_per_prompt=num_images_per_prompt,
|
383 |
+
)
|
384 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
385 |
+
prompt=prompt_2,
|
386 |
+
num_images_per_prompt=num_images_per_prompt,
|
387 |
+
max_sequence_length=max_sequence_length,
|
388 |
+
device=device,
|
389 |
+
)
|
390 |
+
|
391 |
+
if self.text_encoder is not None:
|
392 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
393 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
394 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
395 |
+
|
396 |
+
if self.text_encoder_2 is not None:
|
397 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
398 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
399 |
+
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
400 |
+
|
401 |
+
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
402 |
+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
403 |
+
|
404 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
|
405 |
+
|
406 |
+
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
|
407 |
+
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
408 |
+
if isinstance(generator, list):
|
409 |
+
image_latents = [
|
410 |
+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
411 |
+
for i in range(image.shape[0])
|
412 |
+
]
|
413 |
+
image_latents = torch.cat(image_latents, dim=0)
|
414 |
+
else:
|
415 |
+
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
416 |
+
|
417 |
+
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
418 |
+
|
419 |
+
return image_latents
|
420 |
+
|
421 |
+
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
|
422 |
+
def get_timesteps(self, num_inference_steps, strength, device):
|
423 |
+
# get the original timestep using init_timestep
|
424 |
+
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
425 |
+
|
426 |
+
t_start = int(max(num_inference_steps - init_timestep, 0))
|
427 |
+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
428 |
+
if hasattr(self.scheduler, "set_begin_index"):
|
429 |
+
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
430 |
+
|
431 |
+
return timesteps, num_inference_steps - t_start
|
432 |
+
|
433 |
+
def check_inputs(
|
434 |
+
self,
|
435 |
+
prompt,
|
436 |
+
prompt_2,
|
437 |
+
strength,
|
438 |
+
height,
|
439 |
+
width,
|
440 |
+
prompt_embeds=None,
|
441 |
+
pooled_prompt_embeds=None,
|
442 |
+
callback_on_step_end_tensor_inputs=None,
|
443 |
+
max_sequence_length=None,
|
444 |
+
):
|
445 |
+
if strength < 0 or strength > 1:
|
446 |
+
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
447 |
+
|
448 |
+
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
449 |
+
logger.warning(
|
450 |
+
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
451 |
+
)
|
452 |
+
|
453 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
454 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
455 |
+
):
|
456 |
+
raise ValueError(
|
457 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
458 |
+
)
|
459 |
+
|
460 |
+
if prompt is not None and prompt_embeds is not None:
|
461 |
+
raise ValueError(
|
462 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
463 |
+
" only forward one of the two."
|
464 |
+
)
|
465 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
466 |
+
raise ValueError(
|
467 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
468 |
+
" only forward one of the two."
|
469 |
+
)
|
470 |
+
elif prompt is None and prompt_embeds is None:
|
471 |
+
raise ValueError(
|
472 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
473 |
+
)
|
474 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
475 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
476 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
477 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
478 |
+
|
479 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
480 |
+
raise ValueError(
|
481 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
482 |
+
)
|
483 |
+
|
484 |
+
if max_sequence_length is not None and max_sequence_length > 512:
|
485 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
486 |
+
|
487 |
+
@staticmethod
|
488 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
489 |
+
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
490 |
+
latent_image_ids = torch.zeros(height, width, 3)
|
491 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
492 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
493 |
+
|
494 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
495 |
+
|
496 |
+
latent_image_ids = latent_image_ids.reshape(
|
497 |
+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
498 |
+
)
|
499 |
+
|
500 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
501 |
+
|
502 |
+
@staticmethod
|
503 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
|
504 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
505 |
+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
506 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
507 |
+
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
508 |
+
|
509 |
+
return latents
|
510 |
+
|
511 |
+
@staticmethod
|
512 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
|
513 |
+
def _unpack_latents(latents, height, width, vae_scale_factor):
|
514 |
+
batch_size, num_patches, channels = latents.shape
|
515 |
+
|
516 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
517 |
+
# latent height and width to be divisible by 2.
|
518 |
+
height = 2 * (int(height) // (vae_scale_factor * 2))
|
519 |
+
width = 2 * (int(width) // (vae_scale_factor * 2))
|
520 |
+
|
521 |
+
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
522 |
+
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
523 |
+
|
524 |
+
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
|
525 |
+
|
526 |
+
return latents
|
527 |
+
|
528 |
+
def prepare_latents(
|
529 |
+
self,
|
530 |
+
image,
|
531 |
+
timestep,
|
532 |
+
batch_size,
|
533 |
+
num_channels_latents,
|
534 |
+
height,
|
535 |
+
width,
|
536 |
+
dtype,
|
537 |
+
device,
|
538 |
+
generator,
|
539 |
+
latents=None,
|
540 |
+
):
|
541 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
542 |
+
raise ValueError(
|
543 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
544 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
545 |
+
)
|
546 |
+
|
547 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
548 |
+
# latent height and width to be divisible by 2.
|
549 |
+
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
550 |
+
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
551 |
+
shape = (batch_size, num_channels_latents, height, width)
|
552 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
553 |
+
|
554 |
+
if latents is not None:
|
555 |
+
return latents.to(device=device, dtype=dtype), latent_image_ids
|
556 |
+
|
557 |
+
image = image.to(device=device, dtype=dtype)
|
558 |
+
image_latents = self._encode_vae_image(image=image, generator=generator)
|
559 |
+
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
560 |
+
# expand init_latents for batch_size
|
561 |
+
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
562 |
+
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
563 |
+
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
564 |
+
raise ValueError(
|
565 |
+
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
566 |
+
)
|
567 |
+
else:
|
568 |
+
image_latents = torch.cat([image_latents], dim=0)
|
569 |
+
|
570 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
571 |
+
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
|
572 |
+
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
573 |
+
return latents, latent_image_ids
|
574 |
+
|
575 |
+
@property
|
576 |
+
def guidance_scale(self):
|
577 |
+
return self._guidance_scale
|
578 |
+
|
579 |
+
@property
|
580 |
+
def joint_attention_kwargs(self):
|
581 |
+
return self._joint_attention_kwargs
|
582 |
+
|
583 |
+
@property
|
584 |
+
def num_timesteps(self):
|
585 |
+
return self._num_timesteps
|
586 |
+
|
587 |
+
@property
|
588 |
+
def interrupt(self):
|
589 |
+
return self._interrupt
|
590 |
+
|
591 |
+
@torch.no_grad()
|
592 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
593 |
+
def __call__(
|
594 |
+
self,
|
595 |
+
prompt: Union[str, List[str]] = None,
|
596 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
597 |
+
image: PipelineImageInput = None,
|
598 |
+
height: Optional[int] = None,
|
599 |
+
width: Optional[int] = None,
|
600 |
+
strength: float = 0.6,
|
601 |
+
num_inference_steps: int = 28,
|
602 |
+
sigmas: Optional[List[float]] = None,
|
603 |
+
guidance_scale: float = 7.0,
|
604 |
+
num_images_per_prompt: Optional[int] = 1,
|
605 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
606 |
+
latents: Optional[torch.FloatTensor] = None,
|
607 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
608 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
609 |
+
output_type: Optional[str] = "pil",
|
610 |
+
return_dict: bool = True,
|
611 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
612 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
613 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
614 |
+
max_sequence_length: int = 512,
|
615 |
+
):
|
616 |
+
r"""
|
617 |
+
Function invoked when calling the pipeline for generation.
|
618 |
+
|
619 |
+
Args:
|
620 |
+
prompt (`str` or `List[str]`, *optional*):
|
621 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
622 |
+
instead.
|
623 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
624 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
625 |
+
will be used instead
|
626 |
+
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
627 |
+
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
628 |
+
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
629 |
+
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
630 |
+
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
|
631 |
+
latents as `image`, but if passing latents directly it is not encoded again.
|
632 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
633 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
634 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
635 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
636 |
+
strength (`float`, *optional*, defaults to 1.0):
|
637 |
+
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
638 |
+
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
639 |
+
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
|
640 |
+
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
|
641 |
+
essentially ignores `image`.
|
642 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
643 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
644 |
+
expense of slower inference.
|
645 |
+
sigmas (`List[float]`, *optional*):
|
646 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
647 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
648 |
+
will be used.
|
649 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
650 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
651 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
652 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
653 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
654 |
+
usually at the expense of lower image quality.
|
655 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
656 |
+
The number of images to generate per prompt.
|
657 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
658 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
659 |
+
to make generation deterministic.
|
660 |
+
latents (`torch.FloatTensor`, *optional*):
|
661 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
662 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
663 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
664 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
665 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
666 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
667 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
668 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
669 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
670 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
671 |
+
The output format of the generate image. Choose between
|
672 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
673 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
674 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
675 |
+
joint_attention_kwargs (`dict`, *optional*):
|
676 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
677 |
+
`self.processor` in
|
678 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
679 |
+
callback_on_step_end (`Callable`, *optional*):
|
680 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
681 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
682 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
683 |
+
`callback_on_step_end_tensor_inputs`.
|
684 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
685 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
686 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
687 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
688 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
689 |
+
|
690 |
+
Examples:
|
691 |
+
|
692 |
+
Returns:
|
693 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
694 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
695 |
+
images.
|
696 |
+
"""
|
697 |
+
|
698 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
699 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
700 |
+
|
701 |
+
# 1. Check inputs. Raise error if not correct
|
702 |
+
self.check_inputs(
|
703 |
+
prompt,
|
704 |
+
prompt_2,
|
705 |
+
strength,
|
706 |
+
height,
|
707 |
+
width,
|
708 |
+
prompt_embeds=prompt_embeds,
|
709 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
710 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
711 |
+
max_sequence_length=max_sequence_length,
|
712 |
+
)
|
713 |
+
|
714 |
+
self._guidance_scale = guidance_scale
|
715 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
716 |
+
self._interrupt = False
|
717 |
+
|
718 |
+
# 2. Preprocess image
|
719 |
+
init_image = self.image_processor.preprocess(image, height=height, width=width)
|
720 |
+
init_image = init_image.to(dtype=torch.float32)
|
721 |
+
|
722 |
+
# 3. Define call parameters
|
723 |
+
if prompt is not None and isinstance(prompt, str):
|
724 |
+
batch_size = 1
|
725 |
+
elif prompt is not None and isinstance(prompt, list):
|
726 |
+
batch_size = len(prompt)
|
727 |
+
else:
|
728 |
+
batch_size = prompt_embeds.shape[0]
|
729 |
+
|
730 |
+
device = self._execution_device
|
731 |
+
|
732 |
+
lora_scale = (
|
733 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
734 |
+
)
|
735 |
+
(
|
736 |
+
prompt_embeds,
|
737 |
+
pooled_prompt_embeds,
|
738 |
+
text_ids,
|
739 |
+
) = self.encode_prompt(
|
740 |
+
prompt=prompt,
|
741 |
+
prompt_2=prompt_2,
|
742 |
+
prompt_embeds=prompt_embeds,
|
743 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
744 |
+
device=device,
|
745 |
+
num_images_per_prompt=num_images_per_prompt,
|
746 |
+
max_sequence_length=max_sequence_length,
|
747 |
+
lora_scale=lora_scale,
|
748 |
+
)
|
749 |
+
|
750 |
+
# 4.Prepare timesteps
|
751 |
+
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
|
752 |
+
mu = calculate_shift(
|
753 |
+
image_seq_len,
|
754 |
+
self.scheduler.config.base_image_seq_len,
|
755 |
+
self.scheduler.config.max_image_seq_len,
|
756 |
+
self.scheduler.config.base_shift,
|
757 |
+
self.scheduler.config.max_shift,
|
758 |
+
)
|
759 |
+
# sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
760 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
761 |
+
self.scheduler,
|
762 |
+
num_inference_steps,
|
763 |
+
device,
|
764 |
+
sigmas=sigmas,
|
765 |
+
mu=mu,
|
766 |
+
)
|
767 |
+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
768 |
+
|
769 |
+
if num_inference_steps < 1:
|
770 |
+
raise ValueError(
|
771 |
+
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
|
772 |
+
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
|
773 |
+
)
|
774 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
775 |
+
|
776 |
+
# 5. Prepare latent variables
|
777 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
778 |
+
|
779 |
+
latents, latent_image_ids = self.prepare_latents(
|
780 |
+
init_image,
|
781 |
+
latent_timestep,
|
782 |
+
batch_size * num_images_per_prompt,
|
783 |
+
num_channels_latents,
|
784 |
+
height,
|
785 |
+
width,
|
786 |
+
prompt_embeds.dtype,
|
787 |
+
device,
|
788 |
+
generator,
|
789 |
+
latents,
|
790 |
+
)
|
791 |
+
|
792 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
793 |
+
self._num_timesteps = len(timesteps)
|
794 |
+
|
795 |
+
# handle guidance
|
796 |
+
if self.transformer.config.guidance_embeds:
|
797 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
798 |
+
guidance = guidance.expand(latents.shape[0])
|
799 |
+
else:
|
800 |
+
guidance = None
|
801 |
+
|
802 |
+
# 6. Denoising loop
|
803 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
804 |
+
for i, t in enumerate(timesteps):
|
805 |
+
if self.interrupt:
|
806 |
+
continue
|
807 |
+
|
808 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
809 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
810 |
+
noise_pred = self.transformer(
|
811 |
+
hidden_states=latents,
|
812 |
+
timestep=timestep / 1000,
|
813 |
+
guidance=guidance,
|
814 |
+
pooled_projections=pooled_prompt_embeds,
|
815 |
+
encoder_hidden_states=prompt_embeds,
|
816 |
+
txt_ids=text_ids,
|
817 |
+
img_ids=latent_image_ids,
|
818 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
819 |
+
return_dict=False,
|
820 |
+
)[0]
|
821 |
+
|
822 |
+
# compute the previous noisy sample x_t -> x_t-1
|
823 |
+
latents_dtype = latents.dtype
|
824 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
825 |
+
|
826 |
+
if latents.dtype != latents_dtype:
|
827 |
+
if torch.backends.mps.is_available():
|
828 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
829 |
+
latents = latents.to(latents_dtype)
|
830 |
+
|
831 |
+
if callback_on_step_end is not None:
|
832 |
+
callback_kwargs = {}
|
833 |
+
for k in callback_on_step_end_tensor_inputs:
|
834 |
+
callback_kwargs[k] = locals()[k]
|
835 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
836 |
+
|
837 |
+
latents = callback_outputs.pop("latents", latents)
|
838 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
839 |
+
|
840 |
+
# call the callback, if provided
|
841 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
842 |
+
progress_bar.update()
|
843 |
+
|
844 |
+
if XLA_AVAILABLE:
|
845 |
+
xm.mark_step()
|
846 |
+
|
847 |
+
if output_type == "latent":
|
848 |
+
image = latents
|
849 |
+
|
850 |
+
else:
|
851 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
852 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
853 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
854 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
855 |
+
|
856 |
+
# Offload all models
|
857 |
+
self.maybe_free_model_hooks()
|
858 |
+
|
859 |
+
if not return_dict:
|
860 |
+
return (image,)
|
861 |
+
|
862 |
+
return FluxPipelineOutput(images=image)
|
pipeline/custom_pipelines/pipeline_flux_prior_redux.py
ADDED
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copied from diffusers/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
|
2 |
+
|
3 |
+
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
|
18 |
+
from typing import List, Optional, Union
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from PIL import Image
|
22 |
+
from transformers import (
|
23 |
+
CLIPTextModel,
|
24 |
+
CLIPTokenizer,
|
25 |
+
SiglipImageProcessor,
|
26 |
+
SiglipVisionModel,
|
27 |
+
T5EncoderModel,
|
28 |
+
T5TokenizerFast,
|
29 |
+
)
|
30 |
+
|
31 |
+
from diffusers.image_processor import PipelineImageInput
|
32 |
+
from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
|
33 |
+
from diffusers.utils import (
|
34 |
+
USE_PEFT_BACKEND,
|
35 |
+
is_torch_xla_available,
|
36 |
+
logging,
|
37 |
+
replace_example_docstring,
|
38 |
+
scale_lora_layers,
|
39 |
+
unscale_lora_layers,
|
40 |
+
)
|
41 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
42 |
+
from diffusers.pipelines.flux.modeling_flux import ReduxImageEncoder
|
43 |
+
from diffusers.pipelines.flux.pipeline_output import FluxPriorReduxPipelineOutput
|
44 |
+
|
45 |
+
|
46 |
+
if is_torch_xla_available():
|
47 |
+
XLA_AVAILABLE = True
|
48 |
+
else:
|
49 |
+
XLA_AVAILABLE = False
|
50 |
+
|
51 |
+
|
52 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
53 |
+
|
54 |
+
EXAMPLE_DOC_STRING = """
|
55 |
+
Examples:
|
56 |
+
```py
|
57 |
+
>>> import torch
|
58 |
+
>>> from diffusers import FluxPriorReduxPipeline, FluxPipeline
|
59 |
+
>>> from diffusers.utils import load_image
|
60 |
+
|
61 |
+
>>> device = "cuda"
|
62 |
+
>>> dtype = torch.bfloat16
|
63 |
+
|
64 |
+
>>> repo_redux = "black-forest-labs/FLUX.1-Redux-dev"
|
65 |
+
>>> repo_base = "black-forest-labs/FLUX.1-dev"
|
66 |
+
>>> pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(repo_redux, torch_dtype=dtype).to(device)
|
67 |
+
>>> pipe = FluxPipeline.from_pretrained(
|
68 |
+
... repo_base, text_encoder=None, text_encoder_2=None, torch_dtype=torch.bfloat16
|
69 |
+
... ).to(device)
|
70 |
+
|
71 |
+
>>> image = load_image(
|
72 |
+
... "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png"
|
73 |
+
... )
|
74 |
+
>>> pipe_prior_output = pipe_prior_redux(image)
|
75 |
+
>>> images = pipe(
|
76 |
+
... guidance_scale=2.5,
|
77 |
+
... num_inference_steps=50,
|
78 |
+
... generator=torch.Generator("cpu").manual_seed(0),
|
79 |
+
... **pipe_prior_output,
|
80 |
+
... ).images
|
81 |
+
>>> images[0].save("flux-redux.png")
|
82 |
+
```
|
83 |
+
"""
|
84 |
+
|
85 |
+
|
86 |
+
class FluxPriorReduxPipeline(DiffusionPipeline):
|
87 |
+
r"""
|
88 |
+
The Flux Redux pipeline for image-to-image generation.
|
89 |
+
|
90 |
+
Reference: https://blackforestlabs.ai/flux-1-tools/
|
91 |
+
|
92 |
+
Args:
|
93 |
+
image_encoder ([`SiglipVisionModel`]):
|
94 |
+
SIGLIP vision model to encode the input image.
|
95 |
+
feature_extractor ([`SiglipImageProcessor`]):
|
96 |
+
Image processor for preprocessing images for the SIGLIP model.
|
97 |
+
image_embedder ([`ReduxImageEncoder`]):
|
98 |
+
Redux image encoder to process the SIGLIP embeddings.
|
99 |
+
text_encoder ([`CLIPTextModel`], *optional*):
|
100 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
101 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
102 |
+
text_encoder_2 ([`T5EncoderModel`], *optional*):
|
103 |
+
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
104 |
+
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
105 |
+
tokenizer (`CLIPTokenizer`, *optional*):
|
106 |
+
Tokenizer of class
|
107 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
108 |
+
tokenizer_2 (`T5TokenizerFast`, *optional*):
|
109 |
+
Second Tokenizer of class
|
110 |
+
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
111 |
+
"""
|
112 |
+
|
113 |
+
model_cpu_offload_seq = "image_encoder->image_embedder"
|
114 |
+
_optional_components = [
|
115 |
+
"text_encoder",
|
116 |
+
"tokenizer",
|
117 |
+
"text_encoder_2",
|
118 |
+
"tokenizer_2",
|
119 |
+
]
|
120 |
+
_callback_tensor_inputs = []
|
121 |
+
|
122 |
+
def __init__(
|
123 |
+
self,
|
124 |
+
image_encoder: SiglipVisionModel,
|
125 |
+
feature_extractor: SiglipImageProcessor,
|
126 |
+
image_embedder: ReduxImageEncoder,
|
127 |
+
text_encoder: CLIPTextModel = None,
|
128 |
+
tokenizer: CLIPTokenizer = None,
|
129 |
+
text_encoder_2: T5EncoderModel = None,
|
130 |
+
tokenizer_2: T5TokenizerFast = None,
|
131 |
+
):
|
132 |
+
super().__init__()
|
133 |
+
|
134 |
+
self.register_modules(
|
135 |
+
image_encoder=image_encoder,
|
136 |
+
feature_extractor=feature_extractor,
|
137 |
+
image_embedder=image_embedder,
|
138 |
+
text_encoder=text_encoder,
|
139 |
+
tokenizer=tokenizer,
|
140 |
+
text_encoder_2=text_encoder_2,
|
141 |
+
tokenizer_2=tokenizer_2,
|
142 |
+
)
|
143 |
+
self.tokenizer_max_length = (
|
144 |
+
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
145 |
+
)
|
146 |
+
|
147 |
+
def check_inputs(
|
148 |
+
self,
|
149 |
+
image,
|
150 |
+
prompt,
|
151 |
+
prompt_2,
|
152 |
+
prompt_embeds=None,
|
153 |
+
pooled_prompt_embeds=None,
|
154 |
+
prompt_embeds_scale=1.0,
|
155 |
+
pooled_prompt_embeds_scale=1.0,
|
156 |
+
):
|
157 |
+
if prompt is not None and prompt_embeds is not None:
|
158 |
+
raise ValueError(
|
159 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
160 |
+
" only forward one of the two."
|
161 |
+
)
|
162 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
163 |
+
raise ValueError(
|
164 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
165 |
+
" only forward one of the two."
|
166 |
+
)
|
167 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
168 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
169 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
170 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
171 |
+
if prompt is not None and (isinstance(prompt, list) and isinstance(image, list) and len(prompt) != len(image)):
|
172 |
+
raise ValueError(
|
173 |
+
f"number of prompts must be equal to number of images, but {len(prompt)} prompts were provided and {len(image)} images"
|
174 |
+
)
|
175 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
176 |
+
raise ValueError(
|
177 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
178 |
+
)
|
179 |
+
if isinstance(prompt_embeds_scale, list) and (
|
180 |
+
isinstance(image, list) and len(prompt_embeds_scale) != len(image)
|
181 |
+
):
|
182 |
+
raise ValueError(
|
183 |
+
f"number of weights must be equal to number of images, but {len(prompt_embeds_scale)} weights were provided and {len(image)} images"
|
184 |
+
)
|
185 |
+
|
186 |
+
def encode_image(self, image, device, num_images_per_prompt):
|
187 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
188 |
+
image = self.feature_extractor.preprocess(
|
189 |
+
images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True
|
190 |
+
)
|
191 |
+
image = image.to(device=device, dtype=dtype)
|
192 |
+
|
193 |
+
image_enc_hidden_states = self.image_encoder(**image).last_hidden_state
|
194 |
+
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
195 |
+
|
196 |
+
return image_enc_hidden_states
|
197 |
+
|
198 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
|
199 |
+
def _get_t5_prompt_embeds(
|
200 |
+
self,
|
201 |
+
prompt: Union[str, List[str]] = None,
|
202 |
+
num_images_per_prompt: int = 1,
|
203 |
+
max_sequence_length: int = 512,
|
204 |
+
device: Optional[torch.device] = None,
|
205 |
+
dtype: Optional[torch.dtype] = None,
|
206 |
+
):
|
207 |
+
device = device or self._execution_device
|
208 |
+
dtype = dtype or self.text_encoder.dtype
|
209 |
+
|
210 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
211 |
+
batch_size = len(prompt)
|
212 |
+
|
213 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
214 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
|
215 |
+
|
216 |
+
text_inputs = self.tokenizer_2(
|
217 |
+
prompt,
|
218 |
+
padding="max_length",
|
219 |
+
max_length=max_sequence_length,
|
220 |
+
truncation=True,
|
221 |
+
return_length=False,
|
222 |
+
return_overflowing_tokens=False,
|
223 |
+
return_tensors="pt",
|
224 |
+
)
|
225 |
+
text_input_ids = text_inputs.input_ids
|
226 |
+
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
|
227 |
+
|
228 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
229 |
+
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
230 |
+
logger.warning(
|
231 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
232 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
233 |
+
)
|
234 |
+
|
235 |
+
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
|
236 |
+
|
237 |
+
dtype = self.text_encoder_2.dtype
|
238 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
239 |
+
|
240 |
+
_, seq_len, _ = prompt_embeds.shape
|
241 |
+
|
242 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
243 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
244 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
245 |
+
|
246 |
+
return prompt_embeds
|
247 |
+
|
248 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
|
249 |
+
def _get_clip_prompt_embeds(
|
250 |
+
self,
|
251 |
+
prompt: Union[str, List[str]],
|
252 |
+
num_images_per_prompt: int = 1,
|
253 |
+
device: Optional[torch.device] = None,
|
254 |
+
):
|
255 |
+
device = device or self._execution_device
|
256 |
+
|
257 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
258 |
+
batch_size = len(prompt)
|
259 |
+
|
260 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
261 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
262 |
+
|
263 |
+
text_inputs = self.tokenizer(
|
264 |
+
prompt,
|
265 |
+
padding="max_length",
|
266 |
+
max_length=self.tokenizer_max_length,
|
267 |
+
truncation=True,
|
268 |
+
return_overflowing_tokens=False,
|
269 |
+
return_length=False,
|
270 |
+
return_tensors="pt",
|
271 |
+
)
|
272 |
+
|
273 |
+
text_input_ids = text_inputs.input_ids
|
274 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
275 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
276 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
277 |
+
logger.warning(
|
278 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
279 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
280 |
+
)
|
281 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
282 |
+
|
283 |
+
# Use pooled output of CLIPTextModel
|
284 |
+
prompt_embeds = prompt_embeds.pooler_output
|
285 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
286 |
+
|
287 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
288 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
|
289 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
290 |
+
|
291 |
+
return prompt_embeds
|
292 |
+
|
293 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
|
294 |
+
def encode_prompt(
|
295 |
+
self,
|
296 |
+
prompt: Union[str, List[str]],
|
297 |
+
prompt_2: Union[str, List[str]],
|
298 |
+
device: Optional[torch.device] = None,
|
299 |
+
num_images_per_prompt: int = 1,
|
300 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
301 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
302 |
+
max_sequence_length: int = 512,
|
303 |
+
lora_scale: Optional[float] = None,
|
304 |
+
):
|
305 |
+
r"""
|
306 |
+
|
307 |
+
Args:
|
308 |
+
prompt (`str` or `List[str]`, *optional*):
|
309 |
+
prompt to be encoded
|
310 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
311 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
312 |
+
used in all text-encoders
|
313 |
+
device: (`torch.device`):
|
314 |
+
torch device
|
315 |
+
num_images_per_prompt (`int`):
|
316 |
+
number of images that should be generated per prompt
|
317 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
318 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
319 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
320 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
321 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
322 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
323 |
+
lora_scale (`float`, *optional*):
|
324 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
325 |
+
"""
|
326 |
+
device = device or self._execution_device
|
327 |
+
|
328 |
+
# set lora scale so that monkey patched LoRA
|
329 |
+
# function of text encoder can correctly access it
|
330 |
+
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
331 |
+
self._lora_scale = lora_scale
|
332 |
+
|
333 |
+
# dynamically adjust the LoRA scale
|
334 |
+
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
335 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
336 |
+
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
337 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
338 |
+
|
339 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
340 |
+
|
341 |
+
if prompt_embeds is None:
|
342 |
+
prompt_2 = prompt_2 or prompt
|
343 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
344 |
+
|
345 |
+
# We only use the pooled prompt output from the CLIPTextModel
|
346 |
+
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
347 |
+
prompt=prompt,
|
348 |
+
device=device,
|
349 |
+
num_images_per_prompt=num_images_per_prompt,
|
350 |
+
)
|
351 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
352 |
+
prompt=prompt_2,
|
353 |
+
num_images_per_prompt=num_images_per_prompt,
|
354 |
+
max_sequence_length=max_sequence_length,
|
355 |
+
device=device,
|
356 |
+
)
|
357 |
+
|
358 |
+
if self.text_encoder is not None:
|
359 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
360 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
361 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
362 |
+
|
363 |
+
if self.text_encoder_2 is not None:
|
364 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
365 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
366 |
+
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
367 |
+
|
368 |
+
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
369 |
+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
370 |
+
|
371 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
|
372 |
+
|
373 |
+
@torch.no_grad()
|
374 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
375 |
+
def __call__(
|
376 |
+
self,
|
377 |
+
image: PipelineImageInput,
|
378 |
+
prompt: Union[str, List[str]] = None,
|
379 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
380 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
381 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
382 |
+
prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
|
383 |
+
pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
|
384 |
+
strength: Optional[Union[float, List[float]]] = 1.0,
|
385 |
+
return_dict: bool = True,
|
386 |
+
):
|
387 |
+
r"""
|
388 |
+
Function invoked when calling the pipeline for generation.
|
389 |
+
|
390 |
+
Args:
|
391 |
+
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
392 |
+
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
393 |
+
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
394 |
+
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
395 |
+
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`
|
396 |
+
prompt (`str` or `List[str]`, *optional*):
|
397 |
+
The prompt or prompts to guide the image generation. **experimental feature**: to use this feature,
|
398 |
+
make sure to explicitly load text encoders to the pipeline. Prompts will be ignored if text encoders
|
399 |
+
are not loaded.
|
400 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
401 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`.
|
402 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
403 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
404 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
405 |
+
Pre-generated pooled text embeddings.
|
406 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
407 |
+
Whether or not to return a [`~pipelines.flux.FluxPriorReduxPipelineOutput`] instead of a plain tuple.
|
408 |
+
|
409 |
+
Examples:
|
410 |
+
|
411 |
+
Returns:
|
412 |
+
[`~pipelines.flux.FluxPriorReduxPipelineOutput`] or `tuple`:
|
413 |
+
[`~pipelines.flux.FluxPriorReduxPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
414 |
+
returning a tuple, the first element is a list with the generated images.
|
415 |
+
"""
|
416 |
+
|
417 |
+
# 1. Check inputs. Raise error if not correct
|
418 |
+
self.check_inputs(
|
419 |
+
image,
|
420 |
+
prompt,
|
421 |
+
prompt_2,
|
422 |
+
prompt_embeds=prompt_embeds,
|
423 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
424 |
+
prompt_embeds_scale=prompt_embeds_scale,
|
425 |
+
pooled_prompt_embeds_scale=pooled_prompt_embeds_scale,
|
426 |
+
)
|
427 |
+
|
428 |
+
# 2. Define call parameters
|
429 |
+
if image is not None and isinstance(image, Image.Image):
|
430 |
+
batch_size = 1
|
431 |
+
elif image is not None and isinstance(image, list):
|
432 |
+
batch_size = len(image)
|
433 |
+
else:
|
434 |
+
batch_size = image.shape[0]
|
435 |
+
if prompt is not None and isinstance(prompt, str):
|
436 |
+
prompt = batch_size * [prompt]
|
437 |
+
if isinstance(prompt_embeds_scale, float):
|
438 |
+
prompt_embeds_scale = batch_size * [prompt_embeds_scale]
|
439 |
+
if isinstance(pooled_prompt_embeds_scale, float):
|
440 |
+
pooled_prompt_embeds_scale = batch_size * [pooled_prompt_embeds_scale]
|
441 |
+
if isinstance(strength, float):
|
442 |
+
strength = batch_size * [strength]
|
443 |
+
|
444 |
+
device = self._execution_device
|
445 |
+
|
446 |
+
# 3. Prepare image embeddings
|
447 |
+
image_latents = self.encode_image(image, device, 1)
|
448 |
+
|
449 |
+
image_embeds = self.image_embedder(image_latents).image_embeds
|
450 |
+
image_embeds = image_embeds.to(device=device)
|
451 |
+
|
452 |
+
# 3. Prepare (dummy) text embeddings
|
453 |
+
if hasattr(self, "text_encoder") and self.text_encoder is not None:
|
454 |
+
(
|
455 |
+
prompt_embeds,
|
456 |
+
pooled_prompt_embeds,
|
457 |
+
_,
|
458 |
+
) = self.encode_prompt(
|
459 |
+
prompt=prompt,
|
460 |
+
prompt_2=prompt_2,
|
461 |
+
prompt_embeds=prompt_embeds,
|
462 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
463 |
+
device=device,
|
464 |
+
num_images_per_prompt=1,
|
465 |
+
max_sequence_length=512,
|
466 |
+
lora_scale=None,
|
467 |
+
)
|
468 |
+
else:
|
469 |
+
if prompt is not None:
|
470 |
+
logger.warning(
|
471 |
+
"prompt input is ignored when text encoders are not loaded to the pipeline. "
|
472 |
+
"Make sure to explicitly load the text encoders to enable prompt input. "
|
473 |
+
)
|
474 |
+
# max_sequence_length is 512, t5 encoder hidden size is 4096
|
475 |
+
prompt_embeds = torch.zeros((batch_size, 512, 4096), device=device, dtype=image_embeds.dtype)
|
476 |
+
# pooled_prompt_embeds is 768, clip text encoder hidden size
|
477 |
+
pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype)
|
478 |
+
|
479 |
+
# apply strength to image_embeds
|
480 |
+
image_embeds *= torch.tensor(strength, device=device, dtype=image_embeds.dtype)[:, None, None]
|
481 |
+
|
482 |
+
# scale & concatenate image and text embeddings
|
483 |
+
prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1)
|
484 |
+
|
485 |
+
prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None]
|
486 |
+
pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[
|
487 |
+
:, None
|
488 |
+
]
|
489 |
+
|
490 |
+
# weighted sum
|
491 |
+
prompt_embeds = torch.sum(prompt_embeds, dim=0, keepdim=True)
|
492 |
+
pooled_prompt_embeds = torch.sum(pooled_prompt_embeds, dim=0, keepdim=True)
|
493 |
+
|
494 |
+
# Offload all models
|
495 |
+
self.maybe_free_model_hooks()
|
496 |
+
|
497 |
+
if not return_dict:
|
498 |
+
return (prompt_embeds, pooled_prompt_embeds)
|
499 |
+
|
500 |
+
return FluxPriorReduxPipelineOutput(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds)
|
pipeline/example_text_to_3d.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pipeline.kiss3d_wrapper import init_wrapper_from_config, run_text_to_3d, run_image_to_3d
|
2 |
+
|
3 |
+
if __name__ == "__main__":
|
4 |
+
k3d_wrapper = init_wrapper_from_config('/hpc2hdd/home/jlin695/code/github/Kiss3DGen/pipeline/pipeline_config/default.yaml')
|
5 |
+
|
6 |
+
run_text_to_3d(k3d_wrapper, prompt='A doll of a girl in Harry Potter')
|
7 |
+
|
pipeline/kiss3d_wrapper.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
# The kiss3d pipeline wrapper for inference
|
2 |
|
3 |
import os
|
|
|
4 |
import numpy as np
|
|
|
5 |
import torch
|
6 |
import yaml
|
7 |
import uuid
|
@@ -10,49 +12,93 @@ from einops import rearrange
|
|
10 |
from PIL import Image
|
11 |
|
12 |
from pipeline.utils import logger, TMP_DIR, OUT_DIR
|
13 |
-
from pipeline.utils import lrm_reconstruct, isomer_reconstruct
|
14 |
|
15 |
import torch
|
16 |
import torchvision
|
|
|
17 |
|
18 |
# for reconstruction model
|
19 |
from omegaconf import OmegaConf
|
20 |
from models.lrm.utils.train_util import instantiate_from_config
|
21 |
from models.lrm.utils.render_utils import rotate_x, rotate_y
|
|
|
22 |
from utils.tool import get_background
|
23 |
-
|
24 |
# for florence2
|
25 |
-
from transformers import AutoProcessor, AutoModelForCausalLM
|
26 |
-
|
27 |
-
|
28 |
-
from
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
|
|
31 |
def init_wrapper_from_config(config_path):
|
32 |
with open(config_path, 'r') as config_file:
|
33 |
config_ = yaml.load(config_file, yaml.FullLoader)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
# init flux_pipeline
|
36 |
logger.info('==> Loading Flux model ...')
|
37 |
flux_device = config_['flux'].get('device', 'cpu')
|
38 |
flux_base_model_pth = config_['flux'].get('base_model', None)
|
|
|
39 |
flux_controlnet_pth = config_['flux'].get('controlnet', None)
|
40 |
-
flux_lora_pth = config_['flux'].get('lora', None)
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
flux_pipe =
|
46 |
-
torch_dtype=torch.bfloat16)
|
47 |
else:
|
48 |
-
flux_pipe = FluxImg2ImgPipeline(flux_base_model_pth, torch_dtype=
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
# load lora weights
|
51 |
flux_pipe.load_lora_weights(flux_lora_pth)
|
52 |
-
flux_pipe.to(device=flux_device
|
53 |
|
54 |
-
#
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
# TODO: load pulid model
|
58 |
|
@@ -68,13 +114,15 @@ def init_wrapper_from_config(config_path):
|
|
68 |
multiview_pipeline.scheduler.config, timestep_spacing='trailing'
|
69 |
)
|
70 |
|
71 |
-
unet_ckpt_path = config_['multiview'].get('unet', None)
|
|
|
72 |
if unet_ckpt_path is not None:
|
73 |
state_dict = torch.load(unet_ckpt_path, map_location='cpu')['state_dict']
|
74 |
state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')}
|
75 |
multiview_pipeline.unet.load_state_dict(state_dict, strict=True)
|
76 |
|
77 |
multiview_pipeline.to(multiview_device)
|
|
|
78 |
|
79 |
# load caption model
|
80 |
logger.info('==> Loading caption model ...')
|
@@ -82,6 +130,7 @@ def init_wrapper_from_config(config_path):
|
|
82 |
caption_model = AutoModelForCausalLM.from_pretrained(config_['caption']['base_model'], \
|
83 |
torch_dtype=torch.bfloat16, trust_remote_code=True).to(caption_device)
|
84 |
caption_processor = AutoProcessor.from_pretrained(config_['caption']['base_model'], trust_remote_code=True)
|
|
|
85 |
|
86 |
# load reconstruction model
|
87 |
logger.info('==> Loading reconstruction model ...')
|
@@ -89,40 +138,79 @@ def init_wrapper_from_config(config_path):
|
|
89 |
recon_model_config = OmegaConf.load(config_['reconstruction']['model_config'])
|
90 |
recon_model = instantiate_from_config(recon_model_config.model_config)
|
91 |
# load recon model checkpoint
|
92 |
-
|
|
|
93 |
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
|
94 |
recon_model.load_state_dict(state_dict, strict=True)
|
95 |
recon_model.to(recon_device)
|
96 |
recon_model.init_flexicubes_geometry(recon_device, fovy=50.0)
|
97 |
recon_model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
|
99 |
return kiss3d_wrapper(
|
100 |
config = config_,
|
101 |
flux_pipeline = flux_pipe,
|
|
|
102 |
multiview_pipeline = multiview_pipeline,
|
103 |
caption_processor = caption_processor,
|
104 |
caption_model = caption_model,
|
105 |
reconstruction_model_config = recon_model_config,
|
106 |
reconstruction_model = recon_model,
|
|
|
|
|
107 |
)
|
108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
class kiss3d_wrapper(object):
|
110 |
def __init__(self,
|
111 |
config: Dict,
|
112 |
flux_pipeline: Union[FluxPipeline, FluxControlNetImg2ImgPipeline],
|
|
|
113 |
multiview_pipeline: DiffusionPipeline,
|
114 |
caption_processor: AutoProcessor,
|
115 |
caption_model: AutoModelForCausalLM,
|
116 |
reconstruction_model_config: Any,
|
117 |
reconstruction_model: Any,
|
|
|
|
|
118 |
):
|
119 |
self.config = config
|
120 |
self.flux_pipeline = flux_pipeline
|
|
|
121 |
self.multiview_pipeline = multiview_pipeline
|
122 |
self.caption_model = caption_model
|
123 |
self.caption_processor = caption_processor
|
124 |
self.recon_model_config = reconstruction_model_config
|
125 |
-
self.recon_model = reconstruction_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
self.renew_uuid()
|
128 |
|
@@ -144,12 +232,10 @@ class kiss3d_wrapper(object):
|
|
144 |
caption_device = self.config['caption'].get('device', 'cpu')
|
145 |
|
146 |
if isinstance(image, str): # If image is a file path
|
147 |
-
image = Image.open(image)
|
148 |
-
elif isinstance(image, Image):
|
149 |
-
image = image.convert("RGB")
|
150 |
-
else:
|
151 |
raise NotImplementedError('unexpected image type')
|
152 |
-
|
153 |
prompt = "<MORE_DETAILED_CAPTION>"
|
154 |
inputs = self.caption_processor(text=prompt, images=image, return_tensors="pt").to(caption_device, torch_dtype)
|
155 |
|
@@ -161,17 +247,45 @@ class kiss3d_wrapper(object):
|
|
161 |
parsed_answer = self.caption_processor.post_process_generation(
|
162 |
generated_text, task=prompt, image_size=(image.width, image.height)
|
163 |
)
|
164 |
-
caption_text = parsed_answer["<MORE_DETAILED_CAPTION>"].replace("The image is ", "")
|
|
|
|
|
|
|
|
|
|
|
165 |
return caption_text
|
166 |
|
167 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
with self.context():
|
169 |
mv_image = self.multiview_pipeline(image,
|
170 |
-
num_inference_steps=self.config['multiview']['num_inference_steps'],
|
171 |
-
width=512*2,
|
|
|
|
|
172 |
return mv_image
|
173 |
|
174 |
-
def reconstruct_from_multiview(self, mv_image):
|
175 |
"""
|
176 |
mv_image: PIL.Image
|
177 |
"""
|
@@ -184,23 +298,31 @@ class kiss3d_wrapper(object):
|
|
184 |
with self.context():
|
185 |
vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = \
|
186 |
lrm_reconstruct(self.recon_model, self.recon_model_config.infer_config,
|
187 |
-
rgb_multi_view, name=self.uuid)
|
188 |
|
189 |
-
return vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo
|
190 |
|
191 |
-
def generate_reference_3D_bundle_image_zero123(self, image, save_intermediate_results=True):
|
192 |
"""
|
193 |
input: image, PIL.Image
|
194 |
-
return: ref_3D_bundle_image, Tensor of shape (
|
195 |
"""
|
196 |
mv_image = self.generate_multiview(image)
|
197 |
|
198 |
if save_intermediate_results:
|
199 |
mv_image.save(os.path.join(TMP_DIR, f'{self.uuid}_mv_image.png'))
|
200 |
|
201 |
-
vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = self.reconstruct_from_multiview(mv_image)
|
|
|
|
|
|
|
202 |
|
203 |
-
|
|
|
|
|
|
|
|
|
|
|
204 |
|
205 |
if save_intermediate_results:
|
206 |
save_path = os.path.join(TMP_DIR, f'{self.uuid}_ref_3d_bundle_image.png')
|
@@ -222,6 +344,9 @@ class kiss3d_wrapper(object):
|
|
222 |
control_guidance_end=None,
|
223 |
controlnet_conditioning_scale=None,
|
224 |
lora_scale=1.0,
|
|
|
|
|
|
|
225 |
save_intermediate_results=True,
|
226 |
**kwargs):
|
227 |
control_mode_dict = {
|
@@ -235,15 +360,20 @@ class kiss3d_wrapper(object):
|
|
235 |
} # for https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union only
|
236 |
|
237 |
flux_device = self.config['flux'].get('device', 'cpu')
|
238 |
-
seed = self.config['flux'].get('seed', 0)
|
|
|
239 |
|
240 |
generator = torch.Generator(device=flux_device).manual_seed(seed)
|
241 |
|
|
|
|
|
|
|
242 |
hparam_dict = {
|
243 |
-
'prompt': '
|
244 |
-
'
|
|
|
245 |
'strength': strength,
|
246 |
-
'num_inference_steps':
|
247 |
'guidance_scale': 3.5,
|
248 |
'num_images_per_prompt': 1,
|
249 |
'width': 2048,
|
@@ -253,14 +383,29 @@ class kiss3d_wrapper(object):
|
|
253 |
'joint_attention_kwargs': {"scale": lora_scale}
|
254 |
}
|
255 |
hparam_dict.update(kwargs)
|
256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
# append controlnet hparams
|
258 |
if len(control_image) > 0:
|
259 |
assert isinstance(self.flux_pipeline, FluxControlNetImg2ImgPipeline)
|
260 |
assert len(control_mode) == len(control_image) # the count of image should be the same as control mode
|
261 |
|
262 |
flux_ctrl_net = self.flux_pipeline.controlnet.nets[0]
|
263 |
-
self.flux_pipeline.controlnet = FluxMultiControlNetModel([flux_ctrl_net for
|
264 |
|
265 |
ctrl_hparams = {
|
266 |
'control_mode': [control_mode_dict[mode_] for mode_ in control_mode],
|
@@ -285,13 +430,45 @@ class kiss3d_wrapper(object):
|
|
285 |
|
286 |
return gen_3d_bundle_image_
|
287 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
|
289 |
def generate_3d_bundle_image_text(self,
|
290 |
prompt,
|
291 |
image=None,
|
292 |
strength=1.0,
|
293 |
lora_scale=1.0,
|
294 |
-
num_inference_steps=
|
|
|
|
|
295 |
save_intermediate_results=True,
|
296 |
**kwargs):
|
297 |
|
@@ -299,27 +476,25 @@ class kiss3d_wrapper(object):
|
|
299 |
return: gen_3d_bundle_image, torch.Tensor of shape (3, 1024, 2048), range [0., 1.]
|
300 |
"""
|
301 |
|
302 |
-
if isinstance(self.flux_pipeline,
|
303 |
-
flux_pipeline = FluxImg2ImgPipeline(
|
304 |
-
scheduler = self.flux_pipeline.scheduler,
|
305 |
-
vae = self.flux_pipeline.vae,
|
306 |
-
text_encoder = self.flux_pipeline.text_encoder,
|
307 |
-
tokenizer = self.flux_pipeline.tokenizer,
|
308 |
-
text_encoder_2 = self.flux_pipeline.text_encoder_2,
|
309 |
-
tokenizer_2 = self.flux_pipeline.tokenizer_2,
|
310 |
-
transformer = self.flux_pipeline.transformer
|
311 |
-
)
|
312 |
-
else:
|
313 |
flux_pipeline = self.flux_pipeline
|
|
|
|
|
314 |
|
315 |
flux_device = self.config['flux'].get('device', 'cpu')
|
316 |
-
seed = self.config['flux'].get('seed', 0)
|
|
|
|
|
|
|
|
|
317 |
|
318 |
generator = torch.Generator(device=flux_device).manual_seed(seed)
|
319 |
|
|
|
320 |
hparam_dict = {
|
321 |
-
'prompt': '
|
322 |
-
'
|
|
|
323 |
'strength': strength,
|
324 |
'num_inference_steps': num_inference_steps,
|
325 |
'guidance_scale': 3.5,
|
@@ -332,6 +507,22 @@ class kiss3d_wrapper(object):
|
|
332 |
}
|
333 |
hparam_dict.update(kwargs)
|
334 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
335 |
with self.context():
|
336 |
gen_3d_bundle_image = flux_pipeline(**hparam_dict).images
|
337 |
|
@@ -345,7 +536,13 @@ class kiss3d_wrapper(object):
|
|
345 |
|
346 |
return gen_3d_bundle_image_
|
347 |
|
348 |
-
def reconstruct_3d_bundle_image(self,
|
|
|
|
|
|
|
|
|
|
|
|
|
349 |
"""
|
350 |
image: torch.Tensor, range [0., 1.], (3, 1024, 2048)
|
351 |
"""
|
@@ -355,6 +552,8 @@ class kiss3d_wrapper(object):
|
|
355 |
images = rearrange(image, 'c (n h) (m w) -> (n m) c h w', n=2, m=4) # (3, 1024, 2048) -> (8, 3, 512, 512)
|
356 |
rgb_multi_view, normal_multi_view = images.chunk(2, dim=0)
|
357 |
multi_view_mask = get_background(normal_multi_view).to(recon_device)
|
|
|
|
|
358 |
rgb_multi_view = rgb_multi_view.to(recon_device) * multi_view_mask + (1 - multi_view_mask)
|
359 |
|
360 |
with self.context():
|
@@ -362,11 +561,12 @@ class kiss3d_wrapper(object):
|
|
362 |
lrm_reconstruct(self.recon_model, self.recon_model_config.infer_config,
|
363 |
rgb_multi_view.unsqueeze(0).to(recon_device), name=self.uuid,
|
364 |
input_camera_type='kiss3d', render_3d_bundle_image=save_intermediate_results,
|
365 |
-
render_azimuths=[0, 90, 180, 270]
|
|
|
366 |
|
367 |
if save_intermediate_results:
|
368 |
recon_3D_bundle_image = torchvision.utils.make_grid(torch.cat([lrm_multi_view_rgb.cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0) # range [0, 1]
|
369 |
-
torchvision.utils.save_image(recon_3D_bundle_image, os.path.join(TMP_DIR, f'{
|
370 |
|
371 |
recon_mesh_path = os.path.join(TMP_DIR, f"{self.uuid}_isomer_recon_mesh.obj")
|
372 |
|
@@ -375,7 +575,11 @@ class kiss3d_wrapper(object):
|
|
375 |
multi_view_mask=multi_view_mask,
|
376 |
vertices=vertices,
|
377 |
faces=faces,
|
378 |
-
save_path=recon_mesh_path
|
|
|
|
|
|
|
|
|
379 |
|
380 |
|
381 |
def run_text_to_3d(k3d_wrapper,
|
@@ -391,39 +595,176 @@ def run_text_to_3d(k3d_wrapper,
|
|
391 |
if init_image_path is not None:
|
392 |
init_image = Image.open(init_image_path)
|
393 |
|
|
|
|
|
|
|
|
|
|
|
394 |
gen_3d_bundle_image, gen_save_path = k3d_wrapper.generate_3d_bundle_image_text(prompt,
|
395 |
-
|
396 |
-
|
397 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
398 |
|
399 |
# recon from 3D Bundle image
|
400 |
recon_mesh_path = k3d_wrapper.reconstruct_3d_bundle_image(gen_3d_bundle_image, save_intermediate_results=False)
|
401 |
|
402 |
return gen_save_path, recon_mesh_path
|
403 |
|
404 |
-
|
|
|
405 |
# ======================================= Example of image to 3D generation ======================================
|
406 |
|
407 |
# Renew The uuid
|
408 |
k3d_wrapper.renew_uuid()
|
409 |
|
410 |
# FOR IMAGE TO 3D: generate reference 3D bundle image from a single input image
|
411 |
-
input_image = Image.open(
|
412 |
-
|
|
|
|
|
413 |
caption = k3d_wrapper.get_image_caption(input_image)
|
414 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
415 |
|
416 |
-
|
417 |
-
pdb.set_trace()
|
418 |
|
419 |
|
420 |
if __name__ == "__main__":
|
421 |
-
k3d_wrapper = init_wrapper_from_config('/hpc2hdd/home/jlin695/code/Kiss3DGen/pipeline/pipeline_config/default.yaml')
|
|
|
|
|
|
|
422 |
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
|
427 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
428 |
# run_text_to_3d(k3d_wrapper, prompt='A doll of a girl in Harry Potter')
|
429 |
|
|
|
|
|
|
|
|
|
|
1 |
# The kiss3d pipeline wrapper for inference
|
2 |
|
3 |
import os
|
4 |
+
import spaces
|
5 |
import numpy as np
|
6 |
+
import random
|
7 |
import torch
|
8 |
import yaml
|
9 |
import uuid
|
|
|
12 |
from PIL import Image
|
13 |
|
14 |
from pipeline.utils import logger, TMP_DIR, OUT_DIR
|
15 |
+
from pipeline.utils import lrm_reconstruct, isomer_reconstruct, preprocess_input_image
|
16 |
|
17 |
import torch
|
18 |
import torchvision
|
19 |
+
from torch.nn import functional as F
|
20 |
|
21 |
# for reconstruction model
|
22 |
from omegaconf import OmegaConf
|
23 |
from models.lrm.utils.train_util import instantiate_from_config
|
24 |
from models.lrm.utils.render_utils import rotate_x, rotate_y
|
25 |
+
#
|
26 |
from utils.tool import get_background
|
|
|
27 |
# for florence2
|
28 |
+
from transformers import AutoProcessor, AutoModelForCausalLM, AutoTokenizer
|
29 |
+
from models.llm.llm import load_llm_model, get_llm_response
|
30 |
+
|
31 |
+
from pipeline.custom_pipelines import FluxPriorReduxPipeline, FluxControlNetImg2ImgPipeline, FluxImg2ImgPipeline
|
32 |
+
from diffusers import FluxPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler, FluxTransformer2DModel
|
33 |
+
from diffusers.models.controlnets.controlnet_flux import FluxMultiControlNetModel, FluxControlNetModel
|
34 |
+
from diffusers.schedulers import FlowMatchHeunDiscreteScheduler
|
35 |
+
from huggingface_hub import hf_hub_download
|
36 |
+
access_token = os.getenv("HUGGINGFACE_TOKEN")
|
37 |
+
|
38 |
+
|
39 |
+
def convert_flux_pipeline(exist_flux_pipe, target_pipe, **kwargs):
|
40 |
+
new_pipe = target_pipe(
|
41 |
+
scheduler = exist_flux_pipe.scheduler,
|
42 |
+
vae = exist_flux_pipe.vae,
|
43 |
+
text_encoder = exist_flux_pipe.text_encoder,
|
44 |
+
tokenizer = exist_flux_pipe.tokenizer,
|
45 |
+
text_encoder_2 = exist_flux_pipe.text_encoder_2,
|
46 |
+
tokenizer_2 = exist_flux_pipe.tokenizer_2,
|
47 |
+
transformer = exist_flux_pipe.transformer,
|
48 |
+
**kwargs
|
49 |
+
)
|
50 |
+
return new_pipe
|
51 |
|
52 |
+
@spaces.GPU
|
53 |
def init_wrapper_from_config(config_path):
|
54 |
with open(config_path, 'r') as config_file:
|
55 |
config_ = yaml.load(config_file, yaml.FullLoader)
|
56 |
+
|
57 |
+
dtype_ = {
|
58 |
+
'fp8': torch.float8_e4m3fn,
|
59 |
+
'bf16': torch.bfloat16,
|
60 |
+
'fp16': torch.float16,
|
61 |
+
'fp32': torch.float32
|
62 |
+
}
|
63 |
|
64 |
# init flux_pipeline
|
65 |
logger.info('==> Loading Flux model ...')
|
66 |
flux_device = config_['flux'].get('device', 'cpu')
|
67 |
flux_base_model_pth = config_['flux'].get('base_model', None)
|
68 |
+
flux_dtype = config_['flux'].get('dtype', 'bf16')
|
69 |
flux_controlnet_pth = config_['flux'].get('controlnet', None)
|
70 |
+
# flux_lora_pth = config_['flux'].get('lora', None)
|
71 |
+
flux_lora_pth = hf_hub_download(repo_id="LTT/xxx-ckpt", filename="rgb_normal_large.safetensors", repo_type="model", token=access_token)
|
72 |
+
flux_redux_pth = config_['flux'].get('redux', None)
|
73 |
+
|
74 |
+
if flux_base_model_pth.endswith('safetensors'):
|
75 |
+
flux_pipe = FluxImg2ImgPipeline.from_single_file(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
|
|
|
76 |
else:
|
77 |
+
flux_pipe = FluxImg2ImgPipeline.from_pretrained(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
|
78 |
|
79 |
+
# load flux model and controlnet
|
80 |
+
if flux_controlnet_pth is not None:
|
81 |
+
flux_controlnet = FluxControlNetModel.from_pretrained(flux_controlnet_pth, torch_dtype=torch.bfloat16)
|
82 |
+
flux_pipe = convert_flux_pipeline(flux_pipe, FluxControlNetImg2ImgPipeline, controlnet=[flux_controlnet])
|
83 |
+
|
84 |
+
flux_pipe.scheduler = FlowMatchHeunDiscreteScheduler.from_config(flux_pipe.scheduler.config)
|
85 |
+
|
86 |
# load lora weights
|
87 |
flux_pipe.load_lora_weights(flux_lora_pth)
|
88 |
+
flux_pipe.to(device=flux_device)
|
89 |
|
90 |
+
# load redux model
|
91 |
+
flux_redux_pipe = None
|
92 |
+
if flux_redux_pth is not None:
|
93 |
+
flux_redux_pipe = FluxPriorReduxPipeline.from_pretrained(flux_redux_pth, torch_dtype=torch.bfloat16)
|
94 |
+
flux_redux_pipe.text_encoder = flux_pipe.text_encoder
|
95 |
+
flux_redux_pipe.text_encoder_2 = flux_pipe.text_encoder_2
|
96 |
+
flux_redux_pipe.tokenizer = flux_pipe.tokenizer
|
97 |
+
flux_redux_pipe.tokenizer_2 = flux_pipe.tokenizer_2
|
98 |
+
|
99 |
+
flux_redux_pipe.to(device=flux_device)
|
100 |
+
|
101 |
+
logger.warning(f"GPU memory allocated after load flux model on {flux_device}: {torch.cuda.memory_allocated(device=flux_device) / 1024**3} GB")
|
102 |
|
103 |
# TODO: load pulid model
|
104 |
|
|
|
114 |
multiview_pipeline.scheduler.config, timestep_spacing='trailing'
|
115 |
)
|
116 |
|
117 |
+
# unet_ckpt_path = config_['multiview'].get('unet', None)
|
118 |
+
unet_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="flexgen_19w.ckpt", repo_type="model")
|
119 |
if unet_ckpt_path is not None:
|
120 |
state_dict = torch.load(unet_ckpt_path, map_location='cpu')['state_dict']
|
121 |
state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')}
|
122 |
multiview_pipeline.unet.load_state_dict(state_dict, strict=True)
|
123 |
|
124 |
multiview_pipeline.to(multiview_device)
|
125 |
+
logger.warning(f"GPU memory allocated after load multiview model on {multiview_device}: {torch.cuda.memory_allocated(device=multiview_device) / 1024**3} GB")
|
126 |
|
127 |
# load caption model
|
128 |
logger.info('==> Loading caption model ...')
|
|
|
130 |
caption_model = AutoModelForCausalLM.from_pretrained(config_['caption']['base_model'], \
|
131 |
torch_dtype=torch.bfloat16, trust_remote_code=True).to(caption_device)
|
132 |
caption_processor = AutoProcessor.from_pretrained(config_['caption']['base_model'], trust_remote_code=True)
|
133 |
+
logger.warning(f"GPU memory allocated after load caption model on {caption_device}: {torch.cuda.memory_allocated(device=caption_device) / 1024**3} GB")
|
134 |
|
135 |
# load reconstruction model
|
136 |
logger.info('==> Loading reconstruction model ...')
|
|
|
138 |
recon_model_config = OmegaConf.load(config_['reconstruction']['model_config'])
|
139 |
recon_model = instantiate_from_config(recon_model_config.model_config)
|
140 |
# load recon model checkpoint
|
141 |
+
model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
|
142 |
+
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
143 |
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
|
144 |
recon_model.load_state_dict(state_dict, strict=True)
|
145 |
recon_model.to(recon_device)
|
146 |
recon_model.init_flexicubes_geometry(recon_device, fovy=50.0)
|
147 |
recon_model.eval()
|
148 |
+
logger.warning(f"GPU memory allocated after load reconstruction model on {recon_device}: {torch.cuda.memory_allocated(device=recon_device) / 1024**3} GB")
|
149 |
+
|
150 |
+
# load llm
|
151 |
+
llm_configs = config_.get('llm', None)
|
152 |
+
if llm_configs is not None:
|
153 |
+
logger.info('==> Loading LLM ...')
|
154 |
+
llm_device = llm_configs.get('device', 'cpu')
|
155 |
+
llm, llm_tokenizer = load_llm_model(llm_configs['base_model'])
|
156 |
+
llm.to(llm_device)
|
157 |
+
logger.warning(f"GPU memory allocated after load llm model on {llm_device}: {torch.cuda.memory_allocated(device=llm_device) / 1024**3} GB")
|
158 |
+
else:
|
159 |
+
llm, llm_tokenizer = None, None
|
160 |
|
161 |
return kiss3d_wrapper(
|
162 |
config = config_,
|
163 |
flux_pipeline = flux_pipe,
|
164 |
+
flux_redux_pipeline=flux_redux_pipe,
|
165 |
multiview_pipeline = multiview_pipeline,
|
166 |
caption_processor = caption_processor,
|
167 |
caption_model = caption_model,
|
168 |
reconstruction_model_config = recon_model_config,
|
169 |
reconstruction_model = recon_model,
|
170 |
+
llm_model = llm,
|
171 |
+
llm_tokenizer = llm_tokenizer
|
172 |
)
|
173 |
|
174 |
+
def seed_everything(seed):
|
175 |
+
|
176 |
+
random.seed(seed)
|
177 |
+
np.random.seed(seed)
|
178 |
+
torch.manual_seed(seed)
|
179 |
+
torch.cuda.manual_seed(seed)
|
180 |
+
torch.cuda.manual_seed_all(seed)
|
181 |
+
torch.backends.cudnn.deterministic = True
|
182 |
+
torch.backends.cudnn.benchmark = False
|
183 |
+
|
184 |
+
print(f"Random seed set to {seed}")
|
185 |
+
|
186 |
class kiss3d_wrapper(object):
|
187 |
def __init__(self,
|
188 |
config: Dict,
|
189 |
flux_pipeline: Union[FluxPipeline, FluxControlNetImg2ImgPipeline],
|
190 |
+
flux_redux_pipeline: FluxPriorReduxPipeline,
|
191 |
multiview_pipeline: DiffusionPipeline,
|
192 |
caption_processor: AutoProcessor,
|
193 |
caption_model: AutoModelForCausalLM,
|
194 |
reconstruction_model_config: Any,
|
195 |
reconstruction_model: Any,
|
196 |
+
llm_model: AutoModelForCausalLM = None,
|
197 |
+
llm_tokenizer: AutoTokenizer = None
|
198 |
):
|
199 |
self.config = config
|
200 |
self.flux_pipeline = flux_pipeline
|
201 |
+
self.flux_redux_pipeline = flux_redux_pipeline
|
202 |
self.multiview_pipeline = multiview_pipeline
|
203 |
self.caption_model = caption_model
|
204 |
self.caption_processor = caption_processor
|
205 |
self.recon_model_config = reconstruction_model_config
|
206 |
+
self.recon_model = reconstruction_model
|
207 |
+
self.llm_model = llm_model
|
208 |
+
self.llm_tokenizer = llm_tokenizer
|
209 |
+
|
210 |
+
self.to_512_tensor = torchvision.transforms.Compose([
|
211 |
+
torchvision.transforms.ToTensor(),
|
212 |
+
torchvision.transforms.Resize((512, 512), interpolation=2),
|
213 |
+
])
|
214 |
|
215 |
self.renew_uuid()
|
216 |
|
|
|
232 |
caption_device = self.config['caption'].get('device', 'cpu')
|
233 |
|
234 |
if isinstance(image, str): # If image is a file path
|
235 |
+
image = preprocess_input_image(Image.open(image))
|
236 |
+
elif not isinstance(image, Image.Image):
|
|
|
|
|
237 |
raise NotImplementedError('unexpected image type')
|
238 |
+
|
239 |
prompt = "<MORE_DETAILED_CAPTION>"
|
240 |
inputs = self.caption_processor(text=prompt, images=image, return_tensors="pt").to(caption_device, torch_dtype)
|
241 |
|
|
|
247 |
parsed_answer = self.caption_processor.post_process_generation(
|
248 |
generated_text, task=prompt, image_size=(image.width, image.height)
|
249 |
)
|
250 |
+
caption_text = parsed_answer["<MORE_DETAILED_CAPTION>"] # .replace("The image is ", "")
|
251 |
+
|
252 |
+
logger.info(f"Auto caption result: \"{caption_text}\"")
|
253 |
+
|
254 |
+
caption_text = self.get_detailed_prompt(caption_text)
|
255 |
+
|
256 |
return caption_text
|
257 |
|
258 |
+
def get_detailed_prompt(self, prompt, seed=None):
|
259 |
+
if self.llm_model is not None:
|
260 |
+
detailed_prompt = get_llm_response(self.llm_model, self.llm_tokenizer, prompt, seed=seed)
|
261 |
+
|
262 |
+
logger.info(f"LLM refined prompt result: \"{detailed_prompt}\"")
|
263 |
+
return detailed_prompt
|
264 |
+
return prompt
|
265 |
+
|
266 |
+
def del_llm_model(self):
|
267 |
+
logger.warning('This function is now deprecated and will take no effect')
|
268 |
+
|
269 |
+
# raise NotImplementedError()
|
270 |
+
# del llm.model
|
271 |
+
# del llm.tokenizer
|
272 |
+
# llm.model = None
|
273 |
+
# llm.tokenizer = None
|
274 |
+
|
275 |
+
def generate_multiview(self, image, seed=None, num_inference_steps=None):
|
276 |
+
seed = seed or self.config['multiview'].get('seed', 0)
|
277 |
+
mv_device = self.config['multiview'].get('device', 'cpu')
|
278 |
+
|
279 |
+
generator = torch.Generator(device=mv_device).manual_seed(seed)
|
280 |
with self.context():
|
281 |
mv_image = self.multiview_pipeline(image,
|
282 |
+
num_inference_steps=num_inference_steps or self.config['multiview']['num_inference_steps'],
|
283 |
+
width=512*2,
|
284 |
+
height=512*2,
|
285 |
+
generator=generator).images[0]
|
286 |
return mv_image
|
287 |
|
288 |
+
def reconstruct_from_multiview(self, mv_image, lrm_render_radius=4.15):
|
289 |
"""
|
290 |
mv_image: PIL.Image
|
291 |
"""
|
|
|
298 |
with self.context():
|
299 |
vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = \
|
300 |
lrm_reconstruct(self.recon_model, self.recon_model_config.infer_config,
|
301 |
+
rgb_multi_view, name=self.uuid, render_radius=lrm_render_radius)
|
302 |
|
303 |
+
return rgb_multi_view, vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo
|
304 |
|
305 |
+
def generate_reference_3D_bundle_image_zero123(self, image, use_mv_rgb=False, save_intermediate_results=True):
|
306 |
"""
|
307 |
input: image, PIL.Image
|
308 |
+
return: ref_3D_bundle_image, Tensor of shape (3, 1024, 2048)
|
309 |
"""
|
310 |
mv_image = self.generate_multiview(image)
|
311 |
|
312 |
if save_intermediate_results:
|
313 |
mv_image.save(os.path.join(TMP_DIR, f'{self.uuid}_mv_image.png'))
|
314 |
|
315 |
+
rgb_multi_view, vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = self.reconstruct_from_multiview(mv_image)
|
316 |
+
|
317 |
+
if use_mv_rgb:
|
318 |
+
# ref_3D_bundle_image = torchvision.utils.make_grid(torch.cat([rgb_multi_view[0, [3, 0, 1, 2], ...].cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0) # range [0, 1]
|
319 |
|
320 |
+
rgb_ = torch.cat([rgb_multi_view[0, [3, 0, 1, 2], ...].cpu(), lrm_multi_view_rgb.cpu()], dim=0)
|
321 |
+
ref_3D_bundle_image = torchvision.utils.make_grid(torch.cat([rgb_[[0, 5, 2, 7], ...], (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0) # range [0, 1]
|
322 |
+
else:
|
323 |
+
ref_3D_bundle_image = torchvision.utils.make_grid(torch.cat([lrm_multi_view_rgb.cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0) # range [0, 1]
|
324 |
+
|
325 |
+
ref_3D_bundle_image = ref_3D_bundle_image.clip(0., 1.)
|
326 |
|
327 |
if save_intermediate_results:
|
328 |
save_path = os.path.join(TMP_DIR, f'{self.uuid}_ref_3d_bundle_image.png')
|
|
|
344 |
control_guidance_end=None,
|
345 |
controlnet_conditioning_scale=None,
|
346 |
lora_scale=1.0,
|
347 |
+
num_inference_steps=None,
|
348 |
+
seed=None,
|
349 |
+
redux_hparam=None,
|
350 |
save_intermediate_results=True,
|
351 |
**kwargs):
|
352 |
control_mode_dict = {
|
|
|
360 |
} # for https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union only
|
361 |
|
362 |
flux_device = self.config['flux'].get('device', 'cpu')
|
363 |
+
seed = seed or self.config['flux'].get('seed', 0)
|
364 |
+
num_inference_steps = num_inference_steps or self.config['flux'].get('num_inference_steps', 20)
|
365 |
|
366 |
generator = torch.Generator(device=flux_device).manual_seed(seed)
|
367 |
|
368 |
+
if image is None:
|
369 |
+
image = torch.zeros((1, 3, 1024, 2048), dtype=torch.float32, device=flux_device)
|
370 |
+
|
371 |
hparam_dict = {
|
372 |
+
'prompt': 'A grid of 2x4 multi-view image, elevation 5. White background.',
|
373 |
+
'prompt_2': ' '.join(['A grid of 2x4 multi-view image, elevation 5. White background.', prompt]),
|
374 |
+
'image': image,
|
375 |
'strength': strength,
|
376 |
+
'num_inference_steps': num_inference_steps,
|
377 |
'guidance_scale': 3.5,
|
378 |
'num_images_per_prompt': 1,
|
379 |
'width': 2048,
|
|
|
383 |
'joint_attention_kwargs': {"scale": lora_scale}
|
384 |
}
|
385 |
hparam_dict.update(kwargs)
|
386 |
+
|
387 |
+
# do redux
|
388 |
+
if redux_hparam is not None:
|
389 |
+
assert self.flux_redux_pipeline is not None
|
390 |
+
assert 'image' in redux_hparam.keys()
|
391 |
+
redux_hparam_ = {
|
392 |
+
'prompt': hparam_dict.pop('prompt'),
|
393 |
+
'prompt_2': hparam_dict.pop('prompt_2'),
|
394 |
+
}
|
395 |
+
redux_hparam_.update(redux_hparam)
|
396 |
+
|
397 |
+
with self.context():
|
398 |
+
redux_output = self.flux_redux_pipeline(**redux_hparam_)
|
399 |
+
|
400 |
+
hparam_dict.update(redux_output)
|
401 |
+
|
402 |
# append controlnet hparams
|
403 |
if len(control_image) > 0:
|
404 |
assert isinstance(self.flux_pipeline, FluxControlNetImg2ImgPipeline)
|
405 |
assert len(control_mode) == len(control_image) # the count of image should be the same as control mode
|
406 |
|
407 |
flux_ctrl_net = self.flux_pipeline.controlnet.nets[0]
|
408 |
+
self.flux_pipeline.controlnet = FluxMultiControlNetModel([flux_ctrl_net for _ in control_mode])
|
409 |
|
410 |
ctrl_hparams = {
|
411 |
'control_mode': [control_mode_dict[mode_] for mode_ in control_mode],
|
|
|
430 |
|
431 |
return gen_3d_bundle_image_
|
432 |
|
433 |
+
def preprocess_controlnet_cond_image(self, image, control_mode, save_intermediate_results=True, **kwargs):
|
434 |
+
"""
|
435 |
+
image: Tensor of shape (c, h, w), range [0., 1.]
|
436 |
+
"""
|
437 |
+
if control_mode in ['tile', 'lq']:
|
438 |
+
_, h, w = image.shape
|
439 |
+
down_scale = kwargs.get('down_scale', 4)
|
440 |
+
down_up = torchvision.transforms.Compose([
|
441 |
+
torchvision.transforms.Resize((h // down_scale, w // down_scale), interpolation=2), # 1 for lanczos and 2 for bilinear
|
442 |
+
torchvision.transforms.Resize((h, w), interpolation=2),
|
443 |
+
torchvision.transforms.ToPILImage()
|
444 |
+
])
|
445 |
+
preprocessed = down_up(image)
|
446 |
+
elif control_mode == 'blur':
|
447 |
+
kernel_size = kwargs.get('kernel_size', 51)
|
448 |
+
sigma = kwargs.get('sigma', 2.0)
|
449 |
+
blur = torchvision.transforms.Compose([
|
450 |
+
torchvision.transforms.ToPILImage(),
|
451 |
+
torchvision.transforms.GaussianBlur(kernel_size, sigma),
|
452 |
+
])
|
453 |
+
preprocessed = blur(image)
|
454 |
+
else:
|
455 |
+
raise NotImplementedError(f'Unexpected control mode {control_mode}')
|
456 |
+
|
457 |
+
if save_intermediate_results:
|
458 |
+
save_path = os.path.join(TMP_DIR, f'{self.uuid}_{control_mode}_controlnet_cond.png')
|
459 |
+
preprocessed.save(save_path)
|
460 |
+
logger.info(f'Save image to {save_path}')
|
461 |
+
|
462 |
+
return preprocessed
|
463 |
|
464 |
def generate_3d_bundle_image_text(self,
|
465 |
prompt,
|
466 |
image=None,
|
467 |
strength=1.0,
|
468 |
lora_scale=1.0,
|
469 |
+
num_inference_steps=None,
|
470 |
+
seed=None,
|
471 |
+
redux_hparam=None,
|
472 |
save_intermediate_results=True,
|
473 |
**kwargs):
|
474 |
|
|
|
476 |
return: gen_3d_bundle_image, torch.Tensor of shape (3, 1024, 2048), range [0., 1.]
|
477 |
"""
|
478 |
|
479 |
+
if isinstance(self.flux_pipeline, FluxImg2ImgPipeline):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
480 |
flux_pipeline = self.flux_pipeline
|
481 |
+
else:
|
482 |
+
flux_pipeline = convert_flux_pipeline(self.flux_pipeline, FluxImg2ImgPipeline)
|
483 |
|
484 |
flux_device = self.config['flux'].get('device', 'cpu')
|
485 |
+
seed = seed or self.config['flux'].get('seed', 0)
|
486 |
+
num_inference_steps = num_inference_steps or self.config['flux'].get('num_inference_steps', 20)
|
487 |
+
|
488 |
+
if image is None:
|
489 |
+
image = torch.zeros((1, 3, 1024, 2048), dtype=torch.float32, device=flux_device)
|
490 |
|
491 |
generator = torch.Generator(device=flux_device).manual_seed(seed)
|
492 |
|
493 |
+
|
494 |
hparam_dict = {
|
495 |
+
'prompt': 'A grid of 2x4 multi-view image, elevation 5. White background.',
|
496 |
+
'prompt_2': ' '.join(['A grid of 2x4 multi-view image, elevation 5. White background.', prompt]),
|
497 |
+
'image': image,
|
498 |
'strength': strength,
|
499 |
'num_inference_steps': num_inference_steps,
|
500 |
'guidance_scale': 3.5,
|
|
|
507 |
}
|
508 |
hparam_dict.update(kwargs)
|
509 |
|
510 |
+
# do redux
|
511 |
+
if redux_hparam is not None:
|
512 |
+
assert self.flux_redux_pipeline is not None
|
513 |
+
assert 'image' in redux_hparam.keys()
|
514 |
+
redux_hparam_ = {
|
515 |
+
'prompt': hparam_dict.pop('prompt'),
|
516 |
+
'prompt_2': hparam_dict.pop('prompt_2'),
|
517 |
+
}
|
518 |
+
redux_hparam_.update(redux_hparam)
|
519 |
+
|
520 |
+
with self.context():
|
521 |
+
redux_output = self.flux_redux_pipeline(**redux_hparam_)
|
522 |
+
|
523 |
+
hparam_dict.update(redux_output)
|
524 |
+
|
525 |
+
|
526 |
with self.context():
|
527 |
gen_3d_bundle_image = flux_pipeline(**hparam_dict).images
|
528 |
|
|
|
536 |
|
537 |
return gen_3d_bundle_image_
|
538 |
|
539 |
+
def reconstruct_3d_bundle_image(self,
|
540 |
+
image,
|
541 |
+
lrm_render_radius=4.15,
|
542 |
+
isomer_radius=4.5,
|
543 |
+
reconstruction_stage1_steps=0,
|
544 |
+
reconstruction_stage2_steps=20,
|
545 |
+
save_intermediate_results=True):
|
546 |
"""
|
547 |
image: torch.Tensor, range [0., 1.], (3, 1024, 2048)
|
548 |
"""
|
|
|
552 |
images = rearrange(image, 'c (n h) (m w) -> (n m) c h w', n=2, m=4) # (3, 1024, 2048) -> (8, 3, 512, 512)
|
553 |
rgb_multi_view, normal_multi_view = images.chunk(2, dim=0)
|
554 |
multi_view_mask = get_background(normal_multi_view).to(recon_device)
|
555 |
+
print(f'shape images: {images.shape}')
|
556 |
+
# breakpoint()
|
557 |
rgb_multi_view = rgb_multi_view.to(recon_device) * multi_view_mask + (1 - multi_view_mask)
|
558 |
|
559 |
with self.context():
|
|
|
561 |
lrm_reconstruct(self.recon_model, self.recon_model_config.infer_config,
|
562 |
rgb_multi_view.unsqueeze(0).to(recon_device), name=self.uuid,
|
563 |
input_camera_type='kiss3d', render_3d_bundle_image=save_intermediate_results,
|
564 |
+
render_azimuths=[0, 90, 180, 270],
|
565 |
+
render_radius=lrm_render_radius)
|
566 |
|
567 |
if save_intermediate_results:
|
568 |
recon_3D_bundle_image = torchvision.utils.make_grid(torch.cat([lrm_multi_view_rgb.cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0) # range [0, 1]
|
569 |
+
torchvision.utils.save_image(recon_3D_bundle_image, os.path.join(TMP_DIR, f'{self.uuid}_lrm_recon_3d_bundle_image.png'))
|
570 |
|
571 |
recon_mesh_path = os.path.join(TMP_DIR, f"{self.uuid}_isomer_recon_mesh.obj")
|
572 |
|
|
|
575 |
multi_view_mask=multi_view_mask,
|
576 |
vertices=vertices,
|
577 |
faces=faces,
|
578 |
+
save_path=recon_mesh_path,
|
579 |
+
radius=isomer_radius,
|
580 |
+
reconstruction_stage1_steps=int(reconstruction_stage1_steps),
|
581 |
+
reconstruction_stage2_steps=int(reconstruction_stage2_steps)
|
582 |
+
)
|
583 |
|
584 |
|
585 |
def run_text_to_3d(k3d_wrapper,
|
|
|
595 |
if init_image_path is not None:
|
596 |
init_image = Image.open(init_image_path)
|
597 |
|
598 |
+
# refine prompt
|
599 |
+
logger.info(f"Input prompt: \"{prompt}\"")
|
600 |
+
|
601 |
+
prompt = k3d_wrapper.get_detailed_prompt(prompt)
|
602 |
+
|
603 |
gen_3d_bundle_image, gen_save_path = k3d_wrapper.generate_3d_bundle_image_text(prompt,
|
604 |
+
image=init_image,
|
605 |
+
strength=1.0,
|
606 |
+
save_intermediate_results=True)
|
607 |
+
|
608 |
+
# recon from 3D Bundle image
|
609 |
+
recon_mesh_path = k3d_wrapper.reconstruct_3d_bundle_image(gen_3d_bundle_image, save_intermediate_results=False)
|
610 |
+
|
611 |
+
return gen_save_path, recon_mesh_path
|
612 |
+
|
613 |
+
def image2mesh_preprocess(k3d_wrapper, input_image_, seed, use_mv_rgb=True):
|
614 |
+
seed_everything(seed)
|
615 |
+
|
616 |
+
# Renew The uuid
|
617 |
+
k3d_wrapper.renew_uuid()
|
618 |
+
|
619 |
+
# FOR IMAGE TO 3D: generate reference 3D bundle image from a single input image
|
620 |
+
input_image__ = Image.open(input_image_) if isinstance(input_image_, str) else input_image_
|
621 |
+
|
622 |
+
input_image = preprocess_input_image(input_image__)
|
623 |
+
input_image_save_path = os.path.join(TMP_DIR, f'{k3d_wrapper.uuid}_input_image.png')
|
624 |
+
input_image.save(input_image_save_path)
|
625 |
+
|
626 |
+
reference_3d_bundle_image, reference_save_path = k3d_wrapper.generate_reference_3D_bundle_image_zero123(input_image, use_mv_rgb=use_mv_rgb)
|
627 |
+
caption = k3d_wrapper.get_image_caption(input_image)
|
628 |
+
|
629 |
+
return input_image_save_path, reference_save_path, caption
|
630 |
+
|
631 |
+
def image2mesh_main(k3d_wrapper, input_image, reference_3d_bundle_image, caption, seed, strength1=0.5, strength2=0.95, enable_redux=True, use_controlnet=True):
|
632 |
+
seed_everything(seed)
|
633 |
+
|
634 |
+
if enable_redux:
|
635 |
+
redux_hparam = {
|
636 |
+
'image': k3d_wrapper.to_512_tensor(input_image).unsqueeze(0).clip(0., 1.),
|
637 |
+
'prompt_embeds_scale': 1.0,
|
638 |
+
'pooled_prompt_embeds_scale': 1.0,
|
639 |
+
'strength': strength1
|
640 |
+
}
|
641 |
+
else:
|
642 |
+
redux_hparam = None
|
643 |
+
|
644 |
+
if use_controlnet:
|
645 |
+
# prepare controlnet condition
|
646 |
+
control_mode = ['tile']
|
647 |
+
control_image = [k3d_wrapper.preprocess_controlnet_cond_image(reference_3d_bundle_image, mode_, down_scale=1, kernel_size=51, sigma=2.0) for mode_ in control_mode]
|
648 |
+
control_guidance_start = [0.0]
|
649 |
+
control_guidance_end = [0.3]
|
650 |
+
controlnet_conditioning_scale = [0.3]
|
651 |
+
|
652 |
+
|
653 |
+
gen_3d_bundle_image, gen_save_path = k3d_wrapper.generate_3d_bundle_image_controlnet(
|
654 |
+
prompt=caption,
|
655 |
+
image=reference_3d_bundle_image.unsqueeze(0),
|
656 |
+
strength=strength2,
|
657 |
+
control_image=control_image,
|
658 |
+
control_mode=control_mode,
|
659 |
+
control_guidance_start=control_guidance_start,
|
660 |
+
control_guidance_end=control_guidance_end,
|
661 |
+
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
662 |
+
lora_scale=1.0,
|
663 |
+
redux_hparam=redux_hparam
|
664 |
+
)
|
665 |
+
else:
|
666 |
+
gen_3d_bundle_image, gen_save_path = k3d_wrapper.generate_3d_bundle_image_text(
|
667 |
+
prompt=caption,
|
668 |
+
image=reference_3d_bundle_image.unsqueeze(0),
|
669 |
+
strength=strength2,
|
670 |
+
lora_scale=1.0,
|
671 |
+
redux_hparam=redux_hparam
|
672 |
+
)
|
673 |
|
674 |
# recon from 3D Bundle image
|
675 |
recon_mesh_path = k3d_wrapper.reconstruct_3d_bundle_image(gen_3d_bundle_image, save_intermediate_results=False)
|
676 |
|
677 |
return gen_save_path, recon_mesh_path
|
678 |
|
679 |
+
|
680 |
+
def run_image_to_3d(k3d_wrapper, input_image_path, enable_redux=True, use_mv_rgb=True, use_controlnet=True):
|
681 |
# ======================================= Example of image to 3D generation ======================================
|
682 |
|
683 |
# Renew The uuid
|
684 |
k3d_wrapper.renew_uuid()
|
685 |
|
686 |
# FOR IMAGE TO 3D: generate reference 3D bundle image from a single input image
|
687 |
+
input_image = preprocess_input_image(Image.open(input_image_path))
|
688 |
+
input_image.save(os.path.join(TMP_DIR, f'{k3d_wrapper.uuid}_input_image.png'))
|
689 |
+
|
690 |
+
reference_3d_bundle_image, reference_save_path = k3d_wrapper.generate_reference_3D_bundle_image_zero123(input_image, use_mv_rgb=use_mv_rgb)
|
691 |
caption = k3d_wrapper.get_image_caption(input_image)
|
692 |
|
693 |
+
if enable_redux:
|
694 |
+
redux_hparam = {
|
695 |
+
'image': k3d_wrapper.to_512_tensor(input_image).unsqueeze(0).clip(0., 1.),
|
696 |
+
'prompt_embeds_scale': 1.0,
|
697 |
+
'pooled_prompt_embeds_scale': 1.0,
|
698 |
+
'strength': 0.5
|
699 |
+
}
|
700 |
+
else:
|
701 |
+
redux_hparam = None
|
702 |
+
|
703 |
+
if use_controlnet:
|
704 |
+
# prepare controlnet condition
|
705 |
+
control_mode = ['tile']
|
706 |
+
control_image = [k3d_wrapper.preprocess_controlnet_cond_image(reference_3d_bundle_image, mode_, down_scale=1, kernel_size=51, sigma=2.0) for mode_ in control_mode]
|
707 |
+
control_guidance_start = [0.0]
|
708 |
+
control_guidance_end = [0.3]
|
709 |
+
controlnet_conditioning_scale = [0.3]
|
710 |
+
|
711 |
+
|
712 |
+
gen_3d_bundle_image, gen_save_path = k3d_wrapper.generate_3d_bundle_image_controlnet(
|
713 |
+
prompt=caption,
|
714 |
+
image=reference_3d_bundle_image.unsqueeze(0),
|
715 |
+
strength=.95,
|
716 |
+
control_image=control_image,
|
717 |
+
control_mode=control_mode,
|
718 |
+
control_guidance_start=control_guidance_start,
|
719 |
+
control_guidance_end=control_guidance_end,
|
720 |
+
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
721 |
+
lora_scale=1.0,
|
722 |
+
redux_hparam=redux_hparam
|
723 |
+
)
|
724 |
+
else:
|
725 |
+
gen_3d_bundle_image, gen_save_path = k3d_wrapper.generate_3d_bundle_image_text(
|
726 |
+
prompt=caption,
|
727 |
+
image=reference_3d_bundle_image.unsqueeze(0),
|
728 |
+
strength=.95,
|
729 |
+
lora_scale=1.0,
|
730 |
+
redux_hparam=redux_hparam
|
731 |
+
)
|
732 |
+
|
733 |
+
# recon from 3D Bundle image
|
734 |
+
recon_mesh_path = k3d_wrapper.reconstruct_3d_bundle_image(gen_3d_bundle_image, save_intermediate_results=False)
|
735 |
|
736 |
+
return gen_save_path, recon_mesh_path
|
|
|
737 |
|
738 |
|
739 |
if __name__ == "__main__":
|
740 |
+
k3d_wrapper = init_wrapper_from_config('/hpc2hdd/home/jlin695/code/github/Kiss3DGen/pipeline/pipeline_config/default.yaml')
|
741 |
+
|
742 |
+
os.system(f'rm -rf {TMP_DIR}/*')
|
743 |
+
# os.system(f'rm -rf {OUT_DIR}/3d_bundle/*')
|
744 |
|
745 |
+
enable_redux = True
|
746 |
+
use_mv_rgb = True
|
747 |
+
use_controlnet = True
|
748 |
|
749 |
+
img_folder = '/hpc2hdd/home/jlin695/code/Kiss3DGen/examples'
|
750 |
+
for img_ in os.listdir(img_folder):
|
751 |
+
name, _ = os.path.splitext(img_)
|
752 |
+
print("Now processing:", name)
|
753 |
+
|
754 |
+
gen_save_path, recon_mesh_path = run_image_to_3d(k3d_wrapper, os.path.join(img_folder, img_), enable_redux, use_mv_rgb, use_controlnet)
|
755 |
+
|
756 |
+
os.system(f'cp -f {gen_save_path} {OUT_DIR}/3d_bundle/{name}_3d_bundle.png')
|
757 |
+
os.system(f'cp -f {recon_mesh_path} {OUT_DIR}/3d_bundle/{name}.obj')
|
758 |
+
|
759 |
+
# TODO exams:
|
760 |
+
# 1. redux True, mv_rgb False, Tile, down_scale = 1
|
761 |
+
# 2. redux False, mv_rgb True, Tile, down_scale = 8
|
762 |
+
# 3. redux False, mv_rgb False, Tile, blur = 10
|
763 |
+
|
764 |
+
|
765 |
# run_text_to_3d(k3d_wrapper, prompt='A doll of a girl in Harry Potter')
|
766 |
|
767 |
+
|
768 |
+
# Example of loading existing 3D bundle Image as Tensor from path
|
769 |
+
# pseudo_image = Image.open('/hpc2hdd/home/jlin695/code/github/Kiss3DGen/outputs/tmp/fbf6edad-2d7f-49e5-8ac2-a05af5fe695b_ref_3d_bundle_image.png')
|
770 |
+
# gen_3d_bundle_image = torchvision.transforms.functional.to_tensor(pseudo_image)
|
pipeline/pipeline_config/default.yaml
CHANGED
@@ -1,15 +1,19 @@
|
|
1 |
flux:
|
2 |
-
base_model: "/
|
3 |
-
|
4 |
-
|
5 |
-
|
|
|
|
|
|
|
6 |
device: 'cuda:0'
|
7 |
|
8 |
multiview:
|
9 |
base_model: "sudo-ai/zero123plus-v1.2"
|
10 |
custom_pipeline: "./models/zero123plus"
|
11 |
unet: "./checkpoint/zero123++/flexgen_19w.ckpt"
|
12 |
-
num_inference_steps:
|
|
|
13 |
device: 'cuda:0'
|
14 |
|
15 |
reconstruction:
|
@@ -18,8 +22,12 @@ reconstruction:
|
|
18 |
device: 'cuda:0'
|
19 |
|
20 |
caption:
|
21 |
-
base_model: "/
|
22 |
-
device: 'cuda:
|
|
|
|
|
|
|
|
|
23 |
|
24 |
use_zero_gpu: false # for huggingface demo only
|
25 |
-
3d_bundle_templates: '
|
|
|
1 |
flux:
|
2 |
+
base_model: "https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"
|
3 |
+
flux_dtype: 'fp8'
|
4 |
+
lora: "./checkpoint/flux_lora/rgb_normal_large.safetensors"
|
5 |
+
controlnet: "InstantX/FLUX.1-dev-Controlnet-Union"
|
6 |
+
redux: "black-forest-labs/FLUX.1-Redux-dev"
|
7 |
+
num_inference_steps: 20
|
8 |
+
seed: 42
|
9 |
device: 'cuda:0'
|
10 |
|
11 |
multiview:
|
12 |
base_model: "sudo-ai/zero123plus-v1.2"
|
13 |
custom_pipeline: "./models/zero123plus"
|
14 |
unet: "./checkpoint/zero123++/flexgen_19w.ckpt"
|
15 |
+
num_inference_steps: 50
|
16 |
+
seed: 42
|
17 |
device: 'cuda:0'
|
18 |
|
19 |
reconstruction:
|
|
|
22 |
device: 'cuda:0'
|
23 |
|
24 |
caption:
|
25 |
+
base_model: "multimodalart/Florence-2-large-no-flash-attn"
|
26 |
+
device: 'cuda:1'
|
27 |
+
|
28 |
+
llm:
|
29 |
+
base_model: "Qwen/Qwen2-7B-Instruct"
|
30 |
+
device: 'cuda:1'
|
31 |
|
32 |
use_zero_gpu: false # for huggingface demo only
|
33 |
+
3d_bundle_templates: './init_3d_Bundle'
|
pipeline/run_hpc.sh
CHANGED
@@ -5,6 +5,6 @@ export CC=$(which gcc)
|
|
5 |
export CPLUS_INCLUDE_PATH=/hpc2ssd/softwares/cuda/cuda-12.1/targets/x86_64-linux/include:$CPLUS_INCLUDE_PATH
|
6 |
export CUDA_LAUNCH_BLOCKING=1
|
7 |
export NCCL_TIMEOUT=3600
|
8 |
-
export CUDA_VISIBLE_DEVICES="0"
|
9 |
|
10 |
python ./pipeline/kiss3d_wrapper.py
|
|
|
5 |
export CPLUS_INCLUDE_PATH=/hpc2ssd/softwares/cuda/cuda-12.1/targets/x86_64-linux/include:$CPLUS_INCLUDE_PATH
|
6 |
export CUDA_LAUNCH_BLOCKING=1
|
7 |
export NCCL_TIMEOUT=3600
|
8 |
+
export CUDA_VISIBLE_DEVICES="0,1"
|
9 |
|
10 |
python ./pipeline/kiss3d_wrapper.py
|
run_hpc.sh → pipeline/run_hpc_text_to_3d.sh
RENAMED
@@ -5,7 +5,6 @@ export CC=$(which gcc)
|
|
5 |
export CPLUS_INCLUDE_PATH=/hpc2ssd/softwares/cuda/cuda-12.1/targets/x86_64-linux/include:$CPLUS_INCLUDE_PATH
|
6 |
export CUDA_LAUNCH_BLOCKING=1
|
7 |
export NCCL_TIMEOUT=3600
|
8 |
-
export CUDA_VISIBLE_DEVICES="0"
|
9 |
-
|
10 |
-
python
|
11 |
-
# python image_to_mesh.py
|
|
|
5 |
export CPLUS_INCLUDE_PATH=/hpc2ssd/softwares/cuda/cuda-12.1/targets/x86_64-linux/include:$CPLUS_INCLUDE_PATH
|
6 |
export CUDA_LAUNCH_BLOCKING=1
|
7 |
export NCCL_TIMEOUT=3600
|
8 |
+
export CUDA_VISIBLE_DEVICES="0,1"
|
9 |
+
|
10 |
+
python ./pipeline/example_text_to_3d.py
|
|
pipeline/utils.py
CHANGED
@@ -10,18 +10,20 @@ print(__workdir__)
|
|
10 |
import numpy as np
|
11 |
import torch
|
12 |
from torchvision.transforms import v2
|
|
|
|
|
13 |
|
14 |
from models.lrm.online_render.render_single import load_mipmap
|
15 |
from models.lrm.utils.camera_util import get_zero123plus_input_cameras, get_custom_zero123plus_input_cameras, get_flux_input_cameras
|
16 |
from models.lrm.utils.render_utils import rotate_x, rotate_y
|
17 |
from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
|
|
|
18 |
|
19 |
from models.ISOMER.reconstruction_func import reconstruction
|
20 |
from models.ISOMER.projection_func import projection
|
21 |
|
22 |
from utils.tool import NormalTransfer, get_render_cameras_frames, get_background, get_render_cameras_video, render_frames, mask_fix
|
23 |
|
24 |
-
|
25 |
logging.basicConfig(
|
26 |
level = logging.INFO
|
27 |
)
|
@@ -38,7 +40,7 @@ def lrm_reconstruct(model, infer_config, images,
|
|
38 |
render_3d_bundle_image=True,
|
39 |
render_azimuths=[270, 0, 90, 180],
|
40 |
render_elevations=[5, 5, 5, 5],
|
41 |
-
render_radius=4.
|
42 |
"""
|
43 |
image: Tensor, shape (1, c, h, w)
|
44 |
"""
|
@@ -49,7 +51,7 @@ def lrm_reconstruct(model, infer_config, images,
|
|
49 |
if input_camera_type == 'zero123':
|
50 |
input_cameras = get_custom_zero123plus_input_cameras(batch_size=1, radius=3.5, fov=30).to(device)
|
51 |
elif input_camera_type == 'kiss3d':
|
52 |
-
input_cameras = get_flux_input_cameras(batch_size=1, radius=
|
53 |
else:
|
54 |
raise NotImplementedError(f'Unexpected input camera type: {input_camera_type}')
|
55 |
|
@@ -142,9 +144,9 @@ def isomer_reconstruct(
|
|
142 |
elevations=[5, 5, 5, 5],
|
143 |
geo_weights=[1, 0.9, 1, 0.9],
|
144 |
color_weights=[1, 0.5, 1, 0.5],
|
145 |
-
reconstruction_stage1_steps=
|
146 |
reconstruction_stage2_steps=50,
|
147 |
-
radius=4.
|
148 |
|
149 |
device = rgb_multi_view.device
|
150 |
to_tensor_ = lambda x: torch.Tensor(x).float().to(device)
|
@@ -180,6 +182,7 @@ def isomer_reconstruct(
|
|
180 |
|
181 |
multi_view_mask_proj = mask_fix(multi_view_mask, erode_dilate=-10, blur=5)
|
182 |
|
|
|
183 |
logger.info(f"==> Runing ISOMER projection ...")
|
184 |
save_glb_addr = projection(
|
185 |
meshes,
|
@@ -195,4 +198,29 @@ def isomer_reconstruct(
|
|
195 |
)
|
196 |
|
197 |
logger.info(f"==> Save mesh to {save_glb_addr} ...")
|
198 |
-
return save_glb_addr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
import numpy as np
|
11 |
import torch
|
12 |
from torchvision.transforms import v2
|
13 |
+
from PIL import Image
|
14 |
+
import rembg
|
15 |
|
16 |
from models.lrm.online_render.render_single import load_mipmap
|
17 |
from models.lrm.utils.camera_util import get_zero123plus_input_cameras, get_custom_zero123plus_input_cameras, get_flux_input_cameras
|
18 |
from models.lrm.utils.render_utils import rotate_x, rotate_y
|
19 |
from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
|
20 |
+
from models.lrm.utils.infer_util import remove_background, resize_foreground
|
21 |
|
22 |
from models.ISOMER.reconstruction_func import reconstruction
|
23 |
from models.ISOMER.projection_func import projection
|
24 |
|
25 |
from utils.tool import NormalTransfer, get_render_cameras_frames, get_background, get_render_cameras_video, render_frames, mask_fix
|
26 |
|
|
|
27 |
logging.basicConfig(
|
28 |
level = logging.INFO
|
29 |
)
|
|
|
40 |
render_3d_bundle_image=True,
|
41 |
render_azimuths=[270, 0, 90, 180],
|
42 |
render_elevations=[5, 5, 5, 5],
|
43 |
+
render_radius=4.15):
|
44 |
"""
|
45 |
image: Tensor, shape (1, c, h, w)
|
46 |
"""
|
|
|
51 |
if input_camera_type == 'zero123':
|
52 |
input_cameras = get_custom_zero123plus_input_cameras(batch_size=1, radius=3.5, fov=30).to(device)
|
53 |
elif input_camera_type == 'kiss3d':
|
54 |
+
input_cameras = get_flux_input_cameras(batch_size=1, radius=3.5, fov=30).to(device)
|
55 |
else:
|
56 |
raise NotImplementedError(f'Unexpected input camera type: {input_camera_type}')
|
57 |
|
|
|
144 |
elevations=[5, 5, 5, 5],
|
145 |
geo_weights=[1, 0.9, 1, 0.9],
|
146 |
color_weights=[1, 0.5, 1, 0.5],
|
147 |
+
reconstruction_stage1_steps=10,
|
148 |
reconstruction_stage2_steps=50,
|
149 |
+
radius=4.5):
|
150 |
|
151 |
device = rgb_multi_view.device
|
152 |
to_tensor_ = lambda x: torch.Tensor(x).float().to(device)
|
|
|
182 |
|
183 |
multi_view_mask_proj = mask_fix(multi_view_mask, erode_dilate=-10, blur=5)
|
184 |
|
185 |
+
|
186 |
logger.info(f"==> Runing ISOMER projection ...")
|
187 |
save_glb_addr = projection(
|
188 |
meshes,
|
|
|
198 |
)
|
199 |
|
200 |
logger.info(f"==> Save mesh to {save_glb_addr} ...")
|
201 |
+
return save_glb_addr
|
202 |
+
|
203 |
+
|
204 |
+
def to_rgb_image(maybe_rgba):
|
205 |
+
assert isinstance(maybe_rgba, Image.Image)
|
206 |
+
if maybe_rgba.mode == 'RGB':
|
207 |
+
return maybe_rgba, None
|
208 |
+
elif maybe_rgba.mode == 'RGBA':
|
209 |
+
rgba = maybe_rgba
|
210 |
+
img = np.random.randint(127, 128, size=[rgba.size[1], rgba.size[0], 3], dtype=np.uint8)
|
211 |
+
img = Image.fromarray(img, 'RGB')
|
212 |
+
img.paste(rgba, mask=rgba.getchannel('A'))
|
213 |
+
return img, rgba.getchannel('A')
|
214 |
+
else:
|
215 |
+
raise ValueError("Unsupported image type.", maybe_rgba.mode)
|
216 |
+
|
217 |
+
rembg_session = rembg.new_session("u2net")
|
218 |
+
def preprocess_input_image(input_image):
|
219 |
+
"""
|
220 |
+
input_image: PIL.Image
|
221 |
+
output_image: PIL.Image, (3, 512, 512), mode = RGB, background = white
|
222 |
+
"""
|
223 |
+
image = remove_background(to_rgb_image(input_image)[0], rembg_session, bgcolor=(255, 255, 255, 255))
|
224 |
+
image = resize_foreground(image, ratio=0.85, pad_value=255)
|
225 |
+
return to_rgb_image(image)[0]
|
226 |
+
|
run.sh
DELETED
@@ -1,2 +0,0 @@
|
|
1 |
-
export CUDA_VISIBLE_DEVICES="0"
|
2 |
-
python text_to_mesh.py
|
|
|
|
|
|
text_to_mesh.py
DELETED
@@ -1,232 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from einops import rearrange
|
3 |
-
from omegaconf import OmegaConf
|
4 |
-
import torch
|
5 |
-
import numpy as np
|
6 |
-
import trimesh
|
7 |
-
import torchvision
|
8 |
-
import torch.nn.functional as F
|
9 |
-
from PIL import Image
|
10 |
-
from torchvision import transforms
|
11 |
-
from torchvision.transforms import v2
|
12 |
-
from diffusers import HeunDiscreteScheduler
|
13 |
-
from diffusers import FluxPipeline
|
14 |
-
from pytorch_lightning import seed_everything
|
15 |
-
import os
|
16 |
-
import time
|
17 |
-
from models.lrm.utils.infer_util import save_video
|
18 |
-
from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
|
19 |
-
from models.lrm.utils.render_utils import rotate_x, rotate_y
|
20 |
-
from models.lrm.utils.train_util import instantiate_from_config
|
21 |
-
from models.lrm.utils.camera_util import get_flux_input_cameras
|
22 |
-
from models.ISOMER.reconstruction_func import reconstruction
|
23 |
-
from models.ISOMER.projection_func import projection
|
24 |
-
from utils.tool import NormalTransfer, load_mipmap
|
25 |
-
from utils.tool import get_background, get_render_cameras_video, render_frames
|
26 |
-
|
27 |
-
device = "cuda"
|
28 |
-
resolution = 512
|
29 |
-
save_dir = "./outputs"
|
30 |
-
normal_transfer = NormalTransfer()
|
31 |
-
isomer_azimuths = torch.from_numpy(np.array([0, 90, 180, 270])).float().to(device)
|
32 |
-
isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).float().to(device)
|
33 |
-
isomer_radius = 4.5
|
34 |
-
isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device)
|
35 |
-
isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device)
|
36 |
-
|
37 |
-
# model initialization and loading
|
38 |
-
# flux
|
39 |
-
flux_pipe = FluxPipeline.from_pretrained("/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/models--black-forest-labs--FLUX.1-dev", torch_dtype=torch.bfloat16).to(device=device, dtype=torch.bfloat16)
|
40 |
-
flux_pipe.load_lora_weights('./checkpoint/flux_lora/rgb_normal_large.safetensors')
|
41 |
-
|
42 |
-
flux_pipe.to(device=device, dtype=torch.bfloat16)
|
43 |
-
generator = torch.Generator(device=device).manual_seed(10)
|
44 |
-
|
45 |
-
# lrm
|
46 |
-
config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
|
47 |
-
model_config = config.model_config
|
48 |
-
infer_config = config.infer_config
|
49 |
-
model = instantiate_from_config(model_config)
|
50 |
-
model_ckpt_path = "./checkpoint/lrm/final_ckpt.ckpt"
|
51 |
-
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
52 |
-
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
|
53 |
-
model.load_state_dict(state_dict, strict=True)
|
54 |
-
|
55 |
-
model = model.to(device)
|
56 |
-
model.init_flexicubes_geometry(device, fovy=50.0)
|
57 |
-
model = model.eval()
|
58 |
-
|
59 |
-
# Flux multi-view generation
|
60 |
-
def multi_view_rgb_normal_generation(prompt, save_path=None):
|
61 |
-
# generate multi-view images
|
62 |
-
with torch.no_grad():
|
63 |
-
image = flux_pipe(
|
64 |
-
prompt=prompt,
|
65 |
-
num_inference_steps=30,
|
66 |
-
guidance_scale=3.5,
|
67 |
-
num_images_per_prompt=1,
|
68 |
-
width=resolution*4,
|
69 |
-
height=resolution*2,
|
70 |
-
output_type='np',
|
71 |
-
generator=generator
|
72 |
-
).images
|
73 |
-
return image
|
74 |
-
|
75 |
-
# lrm reconstructions
|
76 |
-
def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
|
77 |
-
images = image.unsqueeze(0).to(device)
|
78 |
-
images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
|
79 |
-
# breakpoint()
|
80 |
-
with torch.no_grad():
|
81 |
-
# get triplane
|
82 |
-
planes = model.forward_planes(images, input_cameras)
|
83 |
-
|
84 |
-
mesh_path_idx = os.path.join(save_path, f'{name}.obj')
|
85 |
-
|
86 |
-
mesh_out = model.extract_mesh(
|
87 |
-
planes,
|
88 |
-
use_texture_map=export_texmap,
|
89 |
-
**infer_config,
|
90 |
-
)
|
91 |
-
if export_texmap:
|
92 |
-
vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
|
93 |
-
save_obj_with_mtl(
|
94 |
-
vertices.data.cpu().numpy(),
|
95 |
-
uvs.data.cpu().numpy(),
|
96 |
-
faces.data.cpu().numpy(),
|
97 |
-
mesh_tex_idx.data.cpu().numpy(),
|
98 |
-
tex_map.permute(1, 2, 0).data.cpu().numpy(),
|
99 |
-
mesh_path_idx,
|
100 |
-
)
|
101 |
-
else:
|
102 |
-
vertices, faces, vertex_colors = mesh_out
|
103 |
-
save_obj(vertices, faces, vertex_colors, mesh_path_idx)
|
104 |
-
print(f"Mesh saved to {mesh_path_idx}")
|
105 |
-
|
106 |
-
render_size = 512
|
107 |
-
if if_save_video:
|
108 |
-
video_path_idx = os.path.join(save_path, f'{name}.mp4')
|
109 |
-
render_size = infer_config.render_resolution
|
110 |
-
ENV = load_mipmap("models/lrm/env_mipmap/6")
|
111 |
-
materials = (0.0,0.9)
|
112 |
-
|
113 |
-
all_mv, all_mvp, all_campos = get_render_cameras_video(
|
114 |
-
batch_size=1,
|
115 |
-
M=240,
|
116 |
-
radius=4.5,
|
117 |
-
elevation=(90, 60.0),
|
118 |
-
is_flexicubes=True,
|
119 |
-
fov=30
|
120 |
-
)
|
121 |
-
|
122 |
-
frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
|
123 |
-
model,
|
124 |
-
planes,
|
125 |
-
render_cameras=all_mvp,
|
126 |
-
camera_pos=all_campos,
|
127 |
-
env=ENV,
|
128 |
-
materials=materials,
|
129 |
-
render_size=render_size,
|
130 |
-
chunk_size=20,
|
131 |
-
is_flexicubes=True,
|
132 |
-
)
|
133 |
-
normals = (torch.nn.functional.normalize(normals) + 1) / 2
|
134 |
-
normals = normals * alphas + (1-alphas)
|
135 |
-
all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
|
136 |
-
|
137 |
-
save_video(
|
138 |
-
all_frames,
|
139 |
-
video_path_idx,
|
140 |
-
fps=30,
|
141 |
-
)
|
142 |
-
print(f"Video saved to {video_path_idx}")
|
143 |
-
|
144 |
-
return vertices, faces
|
145 |
-
|
146 |
-
|
147 |
-
def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg):
|
148 |
-
if local_normal_images.min() >= 0:
|
149 |
-
local_normal = local_normal_images.float() * 2 - 1
|
150 |
-
else:
|
151 |
-
local_normal = local_normal_images.float()
|
152 |
-
global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
|
153 |
-
global_normal[...,0] *= -1
|
154 |
-
global_normal = (global_normal + 1) / 2
|
155 |
-
global_normal = global_normal.permute(0, 3, 1, 2)
|
156 |
-
return global_normal
|
157 |
-
|
158 |
-
def main():
|
159 |
-
end = time.time()
|
160 |
-
fix_prompt = 'a grid of 2x4 multi-view image. elevation 5. white background.'
|
161 |
-
# user prompt
|
162 |
-
prompt = "a owl wearing a hat."
|
163 |
-
save_dir_path = os.path.join(save_dir, prompt.split(".")[0].replace(" ", "_"))
|
164 |
-
os.makedirs(save_dir_path, exist_ok=True)
|
165 |
-
prompt = fix_prompt+" "+prompt
|
166 |
-
# generate multi-view images
|
167 |
-
rgb_normal_grid = multi_view_rgb_normal_generation(prompt)
|
168 |
-
# lrm reconstructions
|
169 |
-
images = torch.from_numpy(rgb_normal_grid).squeeze(0).permute(2, 0, 1).contiguous().float() # (3, 1024, 2048)
|
170 |
-
images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=2, m=4) # (8, 3, 512, 512)
|
171 |
-
rgb_multi_view = images[:4, :3, :, :]
|
172 |
-
normal_multi_view = images[4:, :3, :, :]
|
173 |
-
multi_view_mask = get_background(normal_multi_view)
|
174 |
-
rgb_multi_view = rgb_multi_view * rgb_multi_view + (1-multi_view_mask)
|
175 |
-
input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device)
|
176 |
-
vertices, faces = lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', export_texmap=False, if_save_video=False)
|
177 |
-
# local normal to global normal
|
178 |
-
|
179 |
-
global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1), isomer_azimuths, isomer_elevations)
|
180 |
-
global_normal = global_normal * multi_view_mask + (1-multi_view_mask)
|
181 |
-
|
182 |
-
global_normal = global_normal.permute(0,2,3,1)
|
183 |
-
rgb_multi_view = rgb_multi_view.permute(0,2,3,1)
|
184 |
-
multi_view_mask = multi_view_mask.permute(0,2,3,1).squeeze(-1)
|
185 |
-
vertices = torch.from_numpy(vertices).to(device)
|
186 |
-
faces = torch.from_numpy(faces).to(device)
|
187 |
-
vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
|
188 |
-
vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]
|
189 |
-
|
190 |
-
# global_normal: B,H,W,3
|
191 |
-
# multi_view_mask: B,H,W
|
192 |
-
# rgb_multi_view: B,H,W,3
|
193 |
-
|
194 |
-
meshes = reconstruction(
|
195 |
-
normal_pils=global_normal,
|
196 |
-
masks=multi_view_mask,
|
197 |
-
weights=isomer_geo_weights,
|
198 |
-
fov=30,
|
199 |
-
radius=isomer_radius,
|
200 |
-
camera_angles_azi=isomer_azimuths,
|
201 |
-
camera_angles_ele=isomer_elevations,
|
202 |
-
expansion_weight_stage1=0.1,
|
203 |
-
init_type="file",
|
204 |
-
init_verts=vertices,
|
205 |
-
init_faces=faces,
|
206 |
-
stage1_steps=0,
|
207 |
-
stage2_steps=50,
|
208 |
-
start_edge_len_stage1=0.1,
|
209 |
-
end_edge_len_stage1=0.02,
|
210 |
-
start_edge_len_stage2=0.02,
|
211 |
-
end_edge_len_stage2=0.005,
|
212 |
-
)
|
213 |
-
|
214 |
-
|
215 |
-
save_glb_addr = projection(
|
216 |
-
meshes,
|
217 |
-
masks=multi_view_mask,
|
218 |
-
images=rgb_multi_view,
|
219 |
-
azimuths=isomer_azimuths,
|
220 |
-
elevations=isomer_elevations,
|
221 |
-
weights=isomer_color_weights,
|
222 |
-
fov=30,
|
223 |
-
radius=isomer_radius,
|
224 |
-
save_dir=f"{save_dir_path}/ISOMER/",
|
225 |
-
)
|
226 |
-
print(f'saved to {save_glb_addr}')
|
227 |
-
print(f"Time elapsed: {time.time() - end:.2f}s")
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
if __name__ == '__main__':
|
232 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
text_to_mesh_new.py
DELETED
@@ -1,244 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from einops import rearrange
|
3 |
-
from omegaconf import OmegaConf
|
4 |
-
import torch
|
5 |
-
import numpy as np
|
6 |
-
import trimesh
|
7 |
-
import torchvision
|
8 |
-
import torch.nn.functional as F
|
9 |
-
from PIL import Image
|
10 |
-
from torchvision import transforms
|
11 |
-
from torchvision.transforms import v2
|
12 |
-
from diffusers import HeunDiscreteScheduler
|
13 |
-
from diffusers import FluxPipeline
|
14 |
-
from pytorch_lightning import seed_everything
|
15 |
-
import os
|
16 |
-
|
17 |
-
import time
|
18 |
-
|
19 |
-
from models.lrm.utils.infer_util import save_video
|
20 |
-
from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
|
21 |
-
from models.lrm.utils.render_utils import rotate_x, rotate_y
|
22 |
-
from models.lrm.utils.train_util import instantiate_from_config
|
23 |
-
from models.lrm.utils.camera_util import get_flux_input_cameras
|
24 |
-
from models.ISOMER.reconstruction_func import reconstruction
|
25 |
-
from models.ISOMER.projection_func import projection
|
26 |
-
from utils.tool import NormalTransfer, load_mipmap
|
27 |
-
from utils.tool import get_background, get_render_cameras_video, render_frames, mask_fix
|
28 |
-
|
29 |
-
device = "cuda"
|
30 |
-
resolution = 512
|
31 |
-
save_dir = "./outputs/text2"
|
32 |
-
normal_transfer = NormalTransfer()
|
33 |
-
isomer_azimuths = torch.from_numpy(np.array([0, 90, 180, 270])).float().to(device)
|
34 |
-
isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).float().to(device)
|
35 |
-
isomer_radius = 4.5
|
36 |
-
isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device)
|
37 |
-
isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device)
|
38 |
-
|
39 |
-
# model initialization and loading
|
40 |
-
# flux
|
41 |
-
flux_pipe = FluxPipeline.from_pretrained("/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/models--black-forest-labs--FLUX.1-dev", torch_dtype=torch.bfloat16).to(device=device, dtype=torch.bfloat16)
|
42 |
-
flux_pipe.load_lora_weights('./checkpoint/flux_lora/rgb_normal_large.safetensors')
|
43 |
-
|
44 |
-
flux_pipe.to(device=device, dtype=torch.bfloat16)
|
45 |
-
generator = torch.Generator(device=device).manual_seed(10)
|
46 |
-
|
47 |
-
# lrm
|
48 |
-
config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
|
49 |
-
model_config = config.model_config
|
50 |
-
infer_config = config.infer_config
|
51 |
-
model = instantiate_from_config(model_config)
|
52 |
-
model_ckpt_path = "./checkpoint/lrm/final_ckpt.ckpt"
|
53 |
-
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
54 |
-
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
|
55 |
-
model.load_state_dict(state_dict, strict=True)
|
56 |
-
|
57 |
-
model = model.to(device)
|
58 |
-
model.init_flexicubes_geometry(device, fovy=50.0)
|
59 |
-
model = model.eval()
|
60 |
-
|
61 |
-
# Flux multi-view generation
|
62 |
-
def multi_view_rgb_normal_generation(prompt, save_path=None):
|
63 |
-
# generate multi-view images
|
64 |
-
with torch.no_grad():
|
65 |
-
image = flux_pipe(
|
66 |
-
prompt=prompt,
|
67 |
-
num_inference_steps=30,
|
68 |
-
guidance_scale=3.5,
|
69 |
-
num_images_per_prompt=1,
|
70 |
-
width=resolution*4,
|
71 |
-
height=resolution*2,
|
72 |
-
output_type='np',
|
73 |
-
generator=generator
|
74 |
-
).images
|
75 |
-
return image
|
76 |
-
|
77 |
-
# lrm reconstructions
|
78 |
-
def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
|
79 |
-
images = image.unsqueeze(0).to(device)
|
80 |
-
images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
|
81 |
-
# breakpoint()
|
82 |
-
with torch.no_grad():
|
83 |
-
# get triplane
|
84 |
-
planes = model.forward_planes(images, input_cameras)
|
85 |
-
|
86 |
-
mesh_path_idx = os.path.join(save_path, f'{name}.obj')
|
87 |
-
|
88 |
-
mesh_out = model.extract_mesh(
|
89 |
-
planes,
|
90 |
-
use_texture_map=export_texmap,
|
91 |
-
**infer_config,
|
92 |
-
)
|
93 |
-
if export_texmap:
|
94 |
-
vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
|
95 |
-
save_obj_with_mtl(
|
96 |
-
vertices.data.cpu().numpy(),
|
97 |
-
uvs.data.cpu().numpy(),
|
98 |
-
faces.data.cpu().numpy(),
|
99 |
-
mesh_tex_idx.data.cpu().numpy(),
|
100 |
-
tex_map.permute(1, 2, 0).data.cpu().numpy(),
|
101 |
-
mesh_path_idx,
|
102 |
-
)
|
103 |
-
else:
|
104 |
-
vertices, faces, vertex_colors = mesh_out
|
105 |
-
save_obj(vertices, faces, vertex_colors, mesh_path_idx)
|
106 |
-
print(f"Mesh saved to {mesh_path_idx}")
|
107 |
-
|
108 |
-
render_size = 512
|
109 |
-
if if_save_video:
|
110 |
-
video_path_idx = os.path.join(save_path, f'{name}.mp4')
|
111 |
-
render_size = infer_config.render_resolution
|
112 |
-
ENV = load_mipmap("models/lrm/env_mipmap/6")
|
113 |
-
materials = (0.0,0.9)
|
114 |
-
|
115 |
-
all_mv, all_mvp, all_campos = get_render_cameras_video(
|
116 |
-
batch_size=1,
|
117 |
-
M=240,
|
118 |
-
radius=4.5,
|
119 |
-
elevation=(90, 60.0),
|
120 |
-
is_flexicubes=True,
|
121 |
-
fov=30
|
122 |
-
)
|
123 |
-
|
124 |
-
frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
|
125 |
-
model,
|
126 |
-
planes,
|
127 |
-
render_cameras=all_mvp,
|
128 |
-
camera_pos=all_campos,
|
129 |
-
env=ENV,
|
130 |
-
materials=materials,
|
131 |
-
render_size=render_size,
|
132 |
-
chunk_size=20,
|
133 |
-
is_flexicubes=True,
|
134 |
-
)
|
135 |
-
normals = (torch.nn.functional.normalize(normals) + 1) / 2
|
136 |
-
normals = normals * alphas + (1-alphas)
|
137 |
-
all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
|
138 |
-
|
139 |
-
save_video(
|
140 |
-
all_frames,
|
141 |
-
video_path_idx,
|
142 |
-
fps=30,
|
143 |
-
)
|
144 |
-
print(f"Video saved to {video_path_idx}")
|
145 |
-
|
146 |
-
return vertices, faces
|
147 |
-
|
148 |
-
|
149 |
-
def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg):
|
150 |
-
if local_normal_images.min() >= 0:
|
151 |
-
local_normal = local_normal_images.float() * 2 - 1
|
152 |
-
else:
|
153 |
-
local_normal = local_normal_images.float()
|
154 |
-
global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
|
155 |
-
global_normal[...,0] *= -1
|
156 |
-
global_normal = (global_normal + 1) / 2
|
157 |
-
global_normal = global_normal.permute(0, 3, 1, 2)
|
158 |
-
return global_normal
|
159 |
-
|
160 |
-
def main(prompt = "a owl wearing a hat."):
|
161 |
-
fix_prompt = 'a grid of 2x4 multi-view image. elevation 5. white background.'
|
162 |
-
# user prompt
|
163 |
-
|
164 |
-
save_dir_path = os.path.join(save_dir, prompt.split(".")[0].replace(" ", "_"))
|
165 |
-
os.makedirs(save_dir_path, exist_ok=True)
|
166 |
-
prompt = fix_prompt+" "+prompt
|
167 |
-
# generate multi-view images
|
168 |
-
rgb_normal_grid = multi_view_rgb_normal_generation(prompt)
|
169 |
-
# lrm reconstructions
|
170 |
-
images = torch.from_numpy(rgb_normal_grid).squeeze(0).permute(2, 0, 1).contiguous().float() # (3, 1024, 2048)
|
171 |
-
images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=2, m=4) # (8, 3, 512, 512)
|
172 |
-
rgb_multi_view = images[:4, :3, :, :]
|
173 |
-
normal_multi_view = images[4:, :3, :, :]
|
174 |
-
multi_view_mask = get_background(normal_multi_view)
|
175 |
-
rgb_multi_view = rgb_multi_view * rgb_multi_view + (1-multi_view_mask)
|
176 |
-
input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device)
|
177 |
-
vertices, faces = lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', export_texmap=False, if_save_video=False)
|
178 |
-
# local normal to global normal
|
179 |
-
|
180 |
-
global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1), isomer_azimuths, isomer_elevations)
|
181 |
-
global_normal = global_normal * multi_view_mask + (1-multi_view_mask)
|
182 |
-
|
183 |
-
global_normal = global_normal.permute(0,2,3,1)
|
184 |
-
rgb_multi_view = rgb_multi_view.permute(0,2,3,1)
|
185 |
-
multi_view_mask = multi_view_mask.permute(0,2,3,1).squeeze(-1)
|
186 |
-
vertices = torch.from_numpy(vertices).to(device)
|
187 |
-
faces = torch.from_numpy(faces).to(device)
|
188 |
-
vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
|
189 |
-
vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]
|
190 |
-
|
191 |
-
# global_normal: B,H,W,3
|
192 |
-
# multi_view_mask: B,H,W
|
193 |
-
# rgb_multi_view: B,H,W,3
|
194 |
-
|
195 |
-
multi_view_mask_proj = mask_fix(multi_view_mask, erode_dilate=-6, blur=5)
|
196 |
-
|
197 |
-
meshes = reconstruction(
|
198 |
-
normal_pils=global_normal,
|
199 |
-
masks=multi_view_mask,
|
200 |
-
weights=isomer_geo_weights,
|
201 |
-
fov=30,
|
202 |
-
radius=isomer_radius,
|
203 |
-
camera_angles_azi=isomer_azimuths,
|
204 |
-
camera_angles_ele=isomer_elevations,
|
205 |
-
expansion_weight_stage1=0.1,
|
206 |
-
init_type="file",
|
207 |
-
init_verts=vertices,
|
208 |
-
init_faces=faces,
|
209 |
-
stage1_steps=0,
|
210 |
-
stage2_steps=50,
|
211 |
-
start_edge_len_stage1=0.1,
|
212 |
-
end_edge_len_stage1=0.02,
|
213 |
-
start_edge_len_stage2=0.02,
|
214 |
-
end_edge_len_stage2=0.005,
|
215 |
-
)
|
216 |
-
|
217 |
-
|
218 |
-
multi_view_mask_proj = mask_fix(multi_view_mask, erode_dilate=-10, blur=5)
|
219 |
-
|
220 |
-
save_glb_addr = projection(
|
221 |
-
meshes,
|
222 |
-
masks=multi_view_mask_proj,
|
223 |
-
images=rgb_multi_view,
|
224 |
-
azimuths=isomer_azimuths,
|
225 |
-
elevations=isomer_elevations,
|
226 |
-
weights=isomer_color_weights,
|
227 |
-
fov=30,
|
228 |
-
radius=isomer_radius,
|
229 |
-
save_dir=f"{save_dir_path}/ISOMER/",
|
230 |
-
)
|
231 |
-
print(f'saved to {save_glb_addr}')
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
if __name__ == '__main__':
|
236 |
-
import time
|
237 |
-
start_time = time.time()
|
238 |
-
prompts = ["A red dragon soaring", "A running Chihuahua", "A dancing rabbit", "A girl with blue hair and white dress", "A teacher", "A tiger playing guitar", "A red rose", "A red peony", "A rose in a vase", "A golden retriever sitting", "A golden retriever running"]
|
239 |
-
for prompt in prompts:
|
240 |
-
main(prompt)
|
241 |
-
end_time = time.time()
|
242 |
-
print(f"Time taken: {end_time - start_time:.2f} seconds for {len(prompts)} prompts")
|
243 |
-
|
244 |
-
breakpoint()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
upload_huggingface.py
DELETED
@@ -1,57 +0,0 @@
|
|
1 |
-
from huggingface_hub import HfApi, HfFolder, Repository, create_repo, upload_file
|
2 |
-
import os
|
3 |
-
|
4 |
-
# 登录到 Hugging Face
|
5 |
-
from huggingface_hub import login
|
6 |
-
login()
|
7 |
-
|
8 |
-
# 创建或指定现有的 Repository
|
9 |
-
repo_name = "xxx-ckpt"
|
10 |
-
username = "LTT"
|
11 |
-
repo_id = f"{username}/{repo_name}"
|
12 |
-
|
13 |
-
# 创建仓库(如果它不存在)
|
14 |
-
create_repo(repo_id, exist_ok=True)
|
15 |
-
|
16 |
-
# 文件夹
|
17 |
-
# 上传整个文件夹
|
18 |
-
def upload_folder(folder_path, repo_id):
|
19 |
-
"""
|
20 |
-
递归上传文件夹及其内容到 Hugging Face 仓库。
|
21 |
-
"""
|
22 |
-
for root, _, files in os.walk(folder_path):
|
23 |
-
for file in files:
|
24 |
-
# 文件完整路径
|
25 |
-
full_file_path = os.path.join(root, file)
|
26 |
-
# 相对于文件夹的相对路径(保留文件夹结构)
|
27 |
-
relative_path = os.path.relpath(full_file_path, folder_path)
|
28 |
-
|
29 |
-
# 上传文件到仓库
|
30 |
-
print(f"Uploading {relative_path}...")
|
31 |
-
upload_file(
|
32 |
-
path_or_fileobj=full_file_path,
|
33 |
-
path_in_repo=relative_path,
|
34 |
-
repo_id=repo_id
|
35 |
-
)
|
36 |
-
print(f"Uploaded {relative_path} successfully.")
|
37 |
-
|
38 |
-
|
39 |
-
# 上传模型文件
|
40 |
-
model_path = "checkpoint/zero123++/flexgen_19w.ckpt"
|
41 |
-
upload_file(path_or_fileobj=model_path, path_in_repo="flexgen_19w.ckpt", repo_id=repo_id)
|
42 |
-
|
43 |
-
# # 上传数据文件
|
44 |
-
# data_path = "/hpc2hdd/home/jlin695/data/env_map/data/env_mipmap_large.tar.gz"
|
45 |
-
# upload_file(path_or_fileobj=data_path, path_in_repo="env_mipmap_large.tar.gz", repo_id=repo_id)
|
46 |
-
|
47 |
-
# # 上传数据文件
|
48 |
-
# data_path = "/hpc2hdd/home/jlin695/data/env_map/data/env_map_light_large.tar.gz"
|
49 |
-
# upload_file(path_or_fileobj=data_path, path_in_repo="env_map_light_large.tar.gz", repo_id=repo_id)
|
50 |
-
|
51 |
-
# # 定义要上传的文件夹路径
|
52 |
-
# folder_path = "checkpoint/flux_lora"
|
53 |
-
|
54 |
-
# # 调用上传文件夹的函数
|
55 |
-
# upload_folder(folder_path, repo_id)
|
56 |
-
|
57 |
-
# print("模型和数据文件已上传到 Hugging Face。")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
video_render.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
import imageio
|
5 |
+
import trimesh
|
6 |
+
import pyrender
|
7 |
+
from tqdm import tqdm
|
8 |
+
# os.environ["CUDA_VISIBLE_DEVICES"] = "7"
|
9 |
+
os.environ['PYOPENGL_PLATFORM'] = 'egl' # 设置渲染环境为 EGL(无头模式)
|
10 |
+
|
11 |
+
def render_video_from_obj(input_obj_path, output_video_path, fps=15, frame_count=60, resolution=(512, 512)):
|
12 |
+
"""
|
13 |
+
Render a rotating 3D model (OBJ file) to a video with RGB and normal map side-by-side.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
input_obj_path (str): Path to the input OBJ file.
|
17 |
+
output_video_path (str): Path to save the output video.
|
18 |
+
fps (int): Frames per second for the video.
|
19 |
+
frame_count (int): Number of frames in the video.
|
20 |
+
resolution (tuple): Resolution of the rendered video (width, height).
|
21 |
+
|
22 |
+
Returns:
|
23 |
+
str: Path to the output video.
|
24 |
+
"""
|
25 |
+
# 检查输入文件是否存在
|
26 |
+
if not os.path.exists(input_obj_path):
|
27 |
+
raise FileNotFoundError(f"Input OBJ file not found: {input_obj_path}")
|
28 |
+
|
29 |
+
# 加载3D模型
|
30 |
+
scene_data = trimesh.load(input_obj_path)
|
31 |
+
|
32 |
+
# 提取或合并网格
|
33 |
+
if isinstance(scene_data, trimesh.Scene):
|
34 |
+
mesh_data = trimesh.util.concatenate([geom for geom in scene_data.geometry.values()])
|
35 |
+
else:
|
36 |
+
mesh_data = scene_data
|
37 |
+
|
38 |
+
# 确保顶点法线存在
|
39 |
+
if not hasattr(mesh_data, 'vertex_normals') or mesh_data.vertex_normals is None:
|
40 |
+
mesh_data.compute_vertex_normals()
|
41 |
+
|
42 |
+
# 创建 Pyrender 场景并设置背景为白色
|
43 |
+
render_scene = pyrender.Scene(bg_color=[1.0, 1.0, 1.0])
|
44 |
+
mesh = pyrender.Mesh.from_trimesh(mesh_data, smooth=True)
|
45 |
+
mesh_node = render_scene.add(mesh)
|
46 |
+
|
47 |
+
# 设置摄像机参数
|
48 |
+
camera = pyrender.PerspectiveCamera(yfov=np.deg2rad(30), znear=0.0001, zfar=100000.0)
|
49 |
+
camera_pose = np.eye(4)
|
50 |
+
camera_pose[2, 3] = 4.0 # 距离模型 20 个单位
|
51 |
+
render_scene.add(camera, pose=camera_pose)
|
52 |
+
|
53 |
+
# 添加全局环境光
|
54 |
+
ambient_light = np.array([1.0, 1.0, 1.0]) * 2.0
|
55 |
+
render_scene.ambient_light = ambient_light
|
56 |
+
|
57 |
+
# 准备法线渲染场景
|
58 |
+
normals = mesh_data.vertex_normals.copy()
|
59 |
+
|
60 |
+
# 将法线映射到颜色范围 [0, 255]
|
61 |
+
normal_colors = ((normals + 1) / 2 * 255)
|
62 |
+
|
63 |
+
# 创建用于法线渲染的独立网格
|
64 |
+
normal_mesh_data = mesh_data.copy()
|
65 |
+
normal_mesh_data.visual.vertex_colors = np.hstack(
|
66 |
+
[normal_colors, np.full((normals.shape[0], 1), 255, dtype=np.uint8)] # 添加 Alpha 通道
|
67 |
+
)
|
68 |
+
|
69 |
+
# 创建法线渲染场景
|
70 |
+
normal_scene = pyrender.Scene(bg_color=[1.0, 1.0, 1.0, 1.0])
|
71 |
+
normal_mesh = pyrender.Mesh.from_trimesh(normal_mesh_data, smooth=True)
|
72 |
+
normal_mesh_node = normal_scene.add(normal_mesh)
|
73 |
+
normal_scene.add(camera, pose=camera_pose)
|
74 |
+
normal_scene.ambient_light = ambient_light
|
75 |
+
|
76 |
+
# 初始化渲染器
|
77 |
+
r = pyrender.OffscreenRenderer(*resolution)
|
78 |
+
|
79 |
+
# 创建视频写入器
|
80 |
+
writer = imageio.get_writer(output_video_path, fps=fps)
|
81 |
+
|
82 |
+
# 渲染每一帧
|
83 |
+
try:
|
84 |
+
for frame_idx in tqdm(range(frame_count)):
|
85 |
+
# 计算旋转角度
|
86 |
+
angle = 2 * np.pi * frame_idx / frame_count
|
87 |
+
rotation_matrix = np.array([
|
88 |
+
[math.cos(angle), 0, math.sin(angle), 0],
|
89 |
+
[0, 1, 0, 0],
|
90 |
+
[-math.sin(angle), 0, math.cos(angle), 0],
|
91 |
+
[0, 0, 0, 1]
|
92 |
+
])
|
93 |
+
|
94 |
+
# 更新模型的姿态
|
95 |
+
render_scene.set_pose(mesh_node, rotation_matrix)
|
96 |
+
|
97 |
+
# 渲染 RGB 图像
|
98 |
+
color, _ = r.render(render_scene)
|
99 |
+
|
100 |
+
# 更新法线场景的姿态
|
101 |
+
normal_scene.set_pose(normal_mesh_node, rotation_matrix)
|
102 |
+
|
103 |
+
# 渲染法线图像
|
104 |
+
normal, _ = r.render(normal_scene, flags=pyrender.RenderFlags.FLAT)
|
105 |
+
|
106 |
+
# 拼接左右图像
|
107 |
+
combined_frame = np.concatenate((color, normal), axis=1)
|
108 |
+
|
109 |
+
# 写入视频帧
|
110 |
+
writer.append_data(combined_frame)
|
111 |
+
finally:
|
112 |
+
# 释放资源
|
113 |
+
writer.close()
|
114 |
+
r.delete()
|
115 |
+
|
116 |
+
print(f"Rendered video saved to {output_video_path}")
|
117 |
+
return output_video_path
|
118 |
+
|
119 |
+
if __name__ == '__main__':
|
120 |
+
# 示例调用
|
121 |
+
input_obj_path = "output/gradio_cache/text_3D/_超级赛亚人_10/rgb_projected.obj"
|
122 |
+
output_video_path = "output.mp4"
|
123 |
+
render_video_from_obj(input_obj_path, output_video_path)
|