JiantaoLin commited on
Commit
02a9751
·
1 Parent(s): e33401c
app.py CHANGED
@@ -1,9 +1,35 @@
1
- import gradio as gr
2
  import os
 
3
  import subprocess
4
- import shlex
5
  import spaces
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  access_token = os.getenv("HUGGINGFACE_TOKEN")
8
  subprocess.run(
9
  shlex.split(
@@ -22,6 +48,8 @@ subprocess.run(
22
  "pip install ./extension/renderutils_plugin-0.1.0-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
23
  )
24
  )
 
 
25
  def install_cuda_toolkit():
26
  # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
27
  # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
@@ -41,6 +69,7 @@ def install_cuda_toolkit():
41
  os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
42
  print("==> finfish install")
43
  install_cuda_toolkit()
 
44
  @spaces.GPU
45
  def check_gpu():
46
  os.environ['CUDA_HOME'] = '/usr/local/cuda-12.1'
@@ -48,338 +77,417 @@ def check_gpu():
48
  # os.environ['LD_LIBRARY_PATH'] += ':/usr/local/cuda-12.1/lib64'
49
  os.environ['LD_LIBRARY_PATH'] = "/usr/local/cuda-12.1/lib64:" + os.environ.get('LD_LIBRARY_PATH', '')
50
  subprocess.run(['nvidia-smi']) # 测试 CUDA 是否可用
 
 
 
 
 
 
 
51
  print(f"torch.cuda.is_available:{torch.cuda.is_available()}")
52
  check_gpu()
53
 
54
- from PIL import Image
55
- from einops import rearrange
56
- from diffusers import FluxPipeline
57
- from models.lrm.utils.camera_util import get_flux_input_cameras
58
- from models.lrm.utils.infer_util import save_video
59
- from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
60
- from models.lrm.utils.render_utils import rotate_x, rotate_y
61
- from models.lrm.utils.train_util import instantiate_from_config
62
- from models.ISOMER.reconstruction_func import reconstruction
63
- from models.ISOMER.projection_func import projection
64
- import os
65
- from einops import rearrange
66
- from omegaconf import OmegaConf
67
- import torch
68
- import numpy as np
69
- import trimesh
70
- import torchvision
71
- import torch.nn.functional as F
72
- from PIL import Image
73
- from torchvision import transforms
74
- from torchvision.transforms import v2
75
- from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
76
- from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
77
- from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
78
- from diffusers import FluxPipeline
79
- from pytorch_lightning import seed_everything
80
- import os
81
- from huggingface_hub import hf_hub_download
82
-
83
-
84
- from utils.tool import NormalTransfer, get_background, get_render_cameras_video, load_mipmap, render_frames
85
-
86
- device_0 = "cuda"
87
- device_1 = "cuda"
88
- resolution = 512
89
- save_dir = "./outputs"
90
- normal_transfer = NormalTransfer()
91
- isomer_azimuths = torch.from_numpy(np.array([0, 90, 180, 270])).float().to(device_1)
92
- isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).float().to(device_1)
93
- isomer_radius = 4.5
94
- isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device_1)
95
- isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device_1)
96
-
97
- # model initialization and loading
98
- # flux
99
- # # taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to(device_0)
100
- # # good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16, token=access_token).to(device_0)
101
- # flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, token=access_token).to(device=device_0, dtype=torch.bfloat16)
102
- # # flux_pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, vae=taef1, token=access_token).to(device_0)
103
- # flux_lora_ckpt_path = hf_hub_download(repo_id="LTT/xxx-ckpt", filename="rgb_normal_large.safetensors", repo_type="model", token=access_token)
104
- # flux_pipe.load_lora_weights(flux_lora_ckpt_path)
105
- # flux_pipe.to(device=device_0, dtype=torch.bfloat16)
106
- # torch.cuda.empty_cache()
107
- # flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(flux_pipe)
108
-
109
-
110
- # lrm
111
- config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
112
- model_config = config.model_config
113
- infer_config = config.infer_config
114
- model = instantiate_from_config(model_config)
115
- model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
116
- state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
117
- state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
118
- model.load_state_dict(state_dict, strict=True)
119
- model = model.to(device_1)
120
- torch.cuda.empty_cache()
121
- @spaces.GPU
122
- def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
123
- images = image.unsqueeze(0).to(device_1)
124
- images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
125
- # breakpoint()
126
- with torch.no_grad():
127
- # get triplane
128
- planes = model.forward_planes(images, input_cameras)
129
-
130
- mesh_path_idx = os.path.join(save_path, f'{name}.obj')
131
-
132
- mesh_out = model.extract_mesh(
133
- planes,
134
- use_texture_map=export_texmap,
135
- **infer_config,
136
- )
137
- if export_texmap:
138
- vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
139
- save_obj_with_mtl(
140
- vertices.data.cpu().numpy(),
141
- uvs.data.cpu().numpy(),
142
- faces.data.cpu().numpy(),
143
- mesh_tex_idx.data.cpu().numpy(),
144
- tex_map.permute(1, 2, 0).data.cpu().numpy(),
145
- mesh_path_idx,
146
- )
147
- else:
148
- vertices, faces, vertex_colors = mesh_out
149
- save_obj(vertices, faces, vertex_colors, mesh_path_idx)
150
- print(f"Mesh saved to {mesh_path_idx}")
151
-
152
- render_size = 512
153
- if if_save_video:
154
- video_path_idx = os.path.join(save_path, f'{name}.mp4')
155
- render_size = infer_config.render_resolution
156
- ENV = load_mipmap("models/lrm/env_mipmap/6")
157
- materials = (0.0,0.9)
158
-
159
- all_mv, all_mvp, all_campos = get_render_cameras_video(
160
- batch_size=1,
161
- M=24,
162
- radius=4.5,
163
- elevation=(90, 60.0),
164
- is_flexicubes=True,
165
- fov=30
166
- )
167
-
168
- frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
169
- model,
170
- planes,
171
- render_cameras=all_mvp,
172
- camera_pos=all_campos,
173
- env=ENV,
174
- materials=materials,
175
- render_size=render_size,
176
- chunk_size=20,
177
- is_flexicubes=True,
178
- )
179
- normals = (torch.nn.functional.normalize(normals) + 1) / 2
180
- normals = normals * alphas + (1-alphas)
181
- all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
182
-
183
- save_video(
184
- all_frames,
185
- video_path_idx,
186
- fps=30,
187
- )
188
- print(f"Video saved to {video_path_idx}")
189
 
190
- return vertices, faces
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
- def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg):
194
- if local_normal_images.min() >= 0:
195
- local_normal = local_normal_images.float() * 2 - 1
196
- else:
197
- local_normal = local_normal_images.float()
198
- global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
199
- global_normal[...,0] *= -1
200
- global_normal = (global_normal + 1) / 2
201
- global_normal = global_normal.permute(0, 3, 1, 2)
202
- return global_normal
203
-
204
- # 生成多视图图像
205
- @spaces.GPU(duration=120)
206
- def generate_multi_view_images(prompt, seed):
207
- # torch.cuda.empty_cache()
208
- # generator = torch.manual_seed(seed)
209
- generator = torch.Generator().manual_seed(seed)
210
- with torch.no_grad():
211
- img = flux_pipe(
212
- prompt=prompt,
213
- num_inference_steps=5,
214
- guidance_scale=3.5,
215
- num_images_per_prompt=1,
216
- width=resolution * 2,
217
- height=resolution * 1,
218
- output_type='np',
219
- generator=generator,
220
- ).images
221
- # for img in flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images(
222
- # prompt=prompt,
223
- # guidance_scale=3.5,
224
- # num_inference_steps=4,
225
- # width=resolution * 4,
226
- # height=resolution * 2,
227
- # generator=generator,
228
- # output_type="np",
229
- # good_vae=good_vae,
230
- # ):
231
- # pass
232
- # 返回最终的图像和种子(通过外部调用处理)
233
- return img
234
-
235
- # 重建 3D 模型
236
  @spaces.GPU
237
- def reconstruct_3d_model(images, prompt):
238
- global model
239
- model.init_flexicubes_geometry(device_1, fovy=50.0)
240
- model = model.eval()
241
- rgb_normal_grid = images
242
- save_dir_path = os.path.join(save_dir, prompt.replace(" ", "_"))
243
- os.makedirs(save_dir_path, exist_ok=True)
244
-
245
- images = torch.from_numpy(rgb_normal_grid).squeeze(0).permute(2, 0, 1).contiguous().float() # (3, 1024, 2048)
246
- images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=2, m=4) # (8, 3, 512, 512)
247
- rgb_multi_view = images[:4, :3, :, :]
248
- normal_multi_view = images[4:, :3, :, :]
249
- multi_view_mask = get_background(normal_multi_view)
250
- rgb_multi_view = rgb_multi_view * rgb_multi_view + (1-multi_view_mask)
251
- input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device_1)
252
- vertices, faces = lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', export_texmap=False, if_save_video=True)
253
- # local normal to global normal
254
-
255
- global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1), isomer_azimuths, isomer_elevations)
256
- global_normal = global_normal * multi_view_mask + (1-multi_view_mask)
257
-
258
- global_normal = global_normal.permute(0,2,3,1)
259
- rgb_multi_view = rgb_multi_view.permute(0,2,3,1)
260
- multi_view_mask = multi_view_mask.permute(0,2,3,1).squeeze(-1)
261
- vertices = torch.from_numpy(vertices).to(device_1)
262
- faces = torch.from_numpy(faces).to(device_1)
263
- vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
264
- vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]
265
-
266
- # global_normal: B,H,W,3
267
- # multi_view_mask: B,H,W
268
- # rgb_multi_view: B,H,W,3
269
 
270
- meshes = reconstruction(
271
- normal_pils=global_normal,
272
- masks=multi_view_mask,
273
- weights=isomer_geo_weights,
274
- fov=30,
275
- radius=isomer_radius,
276
- camera_angles_azi=isomer_azimuths,
277
- camera_angles_ele=isomer_elevations,
278
- expansion_weight_stage1=0.1,
279
- init_type="file",
280
- init_verts=vertices,
281
- init_faces=faces,
282
- stage1_steps=0,
283
- stage2_steps=50,
284
- start_edge_len_stage1=0.1,
285
- end_edge_len_stage1=0.02,
286
- start_edge_len_stage2=0.02,
287
- end_edge_len_stage2=0.005,
288
- )
289
 
 
 
290
 
291
- save_glb_addr = projection(
292
- meshes,
293
- masks=multi_view_mask,
294
- images=rgb_multi_view,
295
- azimuths=isomer_azimuths,
296
- elevations=isomer_elevations,
297
- weights=isomer_color_weights,
298
- fov=30,
299
- radius=isomer_radius,
300
- save_dir=f"{save_dir_path}/ISOMER/",
301
- )
302
 
303
- return save_glb_addr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
 
305
- # Gradio 接口函数
306
  @spaces.GPU
307
- def gradio_pipeline(prompt, seed):
308
- import ctypes
309
- # 显式加载 libnvrtc.so.12
310
- cuda_lib_path = "/usr/local/cuda-12.1/lib64/libnvrtc.so.12"
311
- try:
312
- ctypes.CDLL(cuda_lib_path, mode=ctypes.RTLD_GLOBAL)
313
- print(f"Successfully preloaded {cuda_lib_path}")
314
- except OSError as e:
315
- print(f"Failed to preload {cuda_lib_path}: {e}")
316
- # 生成多视图图像
317
- # rgb_normal_grid = generate_multi_view_images(prompt, seed)
318
- rgb_normal_grid = np.load("rgb_normal_grid.npy")
319
- image_preview = Image.fromarray((rgb_normal_grid[0] * 255).astype(np.uint8))
320
-
321
- # 3d reconstruction
322
-
323
-
324
- # 重建 3D 模型并返回 glb 路径
325
- save_glb_addr = reconstruct_3d_model(rgb_normal_grid, prompt)
326
- # save_glb_addr = None
327
- return image_preview, save_glb_addr
328
-
329
- # Gradio Blocks 应用
330
- with gr.Blocks() as demo:
331
- with gr.Row(variant="panel"):
332
- # 左侧输入区域
333
- with gr.Column():
334
- with gr.Row():
335
- prompt_input = gr.Textbox(
336
- label="Enter Prompt",
337
- placeholder="Describe your 3D model...",
338
- lines=2,
339
- elem_id="prompt_input"
340
- )
341
-
342
- with gr.Row():
343
- sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
344
-
345
- with gr.Row():
346
- submit = gr.Button("Generate", elem_id="generate", variant="primary")
347
-
348
- with gr.Row(variant="panel"):
349
- gr.Markdown("Examples:")
350
- gr.Examples(
351
- examples=[
352
- ["a castle on a hill"],
353
- ["an owl wearing a hat"],
354
- ["a futuristic car"]
355
- ],
356
- inputs=[prompt_input],
357
- label="Prompt Examples"
358
- )
359
-
360
- # 右侧输出区域
361
- with gr.Column():
362
- with gr.Row():
363
- rgb_normal_grid_image = gr.Image(
364
- label="RGB Normal Grid",
365
- type="pil",
366
- interactive=False
367
- )
368
-
369
- with gr.Row():
370
- with gr.Tab("GLB"):
371
- output_glb_model = gr.Model3D(
372
- label="Generated 3D Model (GLB Format)",
373
- interactive=False
374
- )
375
- gr.Markdown("Download the model for proper visualization.")
376
-
377
- # 处理逻辑
378
- submit.click(
379
- fn=gradio_pipeline, inputs=[prompt_input, sample_seed],
380
- outputs=[rgb_normal_grid_image, output_glb_model]
381
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
 
383
- # 启动应用
384
- # demo.queue(max_size=10)
385
- demo.launch()
 
 
1
  import os
2
+ import gradio as gr
3
  import subprocess
 
4
  import spaces
5
+ import ctypes
6
+ import shlex
7
+ import base64
8
+ import re
9
+ import sys
10
+
11
+ from models.ISOMER.scripts.utils import fix_vert_color_glb
12
+
13
+ sys.path.append(os.path.abspath(os.path.join(__file__, '../')))
14
+ if 'OMP_NUM_THREADS' not in os.environ:
15
+ os.environ['OMP_NUM_THREADS'] = '32'
16
+
17
+ import shutil
18
  import torch
19
+ import json
20
+ import requests
21
+ import shutil
22
+ import threading
23
+ from PIL import Image
24
+ import time
25
+ torch.backends.cuda.matmul.allow_tf32 = True
26
+ import trimesh
27
+
28
+ import random
29
+ import time
30
+ import numpy as np
31
+ from video_render import render_video_from_obj
32
+
33
  access_token = os.getenv("HUGGINGFACE_TOKEN")
34
  subprocess.run(
35
  shlex.split(
 
48
  "pip install ./extension/renderutils_plugin-0.1.0-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
49
  )
50
  )
51
+
52
+ # download cudatoolkit
53
  def install_cuda_toolkit():
54
  # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
55
  # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
 
69
  os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
70
  print("==> finfish install")
71
  install_cuda_toolkit()
72
+
73
  @spaces.GPU
74
  def check_gpu():
75
  os.environ['CUDA_HOME'] = '/usr/local/cuda-12.1'
 
77
  # os.environ['LD_LIBRARY_PATH'] += ':/usr/local/cuda-12.1/lib64'
78
  os.environ['LD_LIBRARY_PATH'] = "/usr/local/cuda-12.1/lib64:" + os.environ.get('LD_LIBRARY_PATH', '')
79
  subprocess.run(['nvidia-smi']) # 测试 CUDA 是否可用
80
+ # 显式加载 libnvrtc.so.12
81
+ cuda_lib_path = "/usr/local/cuda-12.1/lib64/libnvrtc.so.12"
82
+ try:
83
+ ctypes.CDLL(cuda_lib_path, mode=ctypes.RTLD_GLOBAL)
84
+ print(f"Successfully preloaded {cuda_lib_path}")
85
+ except OSError as e:
86
+ print(f"Failed to preload {cuda_lib_path}: {e}")
87
  print(f"torch.cuda.is_available:{torch.cuda.is_available()}")
88
  check_gpu()
89
 
90
+ from pipeline.kiss3d_wrapper import init_wrapper_from_config, run_text_to_3d, run_image_to_3d, image2mesh_preprocess, image2mesh_main
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
+ is_running = False
93
 
94
+ TEXT_URL = "http://127.0.0.1:9239/prompt"
95
+ IMG_URL = ""
96
+
97
+
98
+ KISS_3D_TEXT_FOLDER = "./outputs/text2"
99
+ KISS_3D_IMG_FOLDER = "./outputs/image2"
100
+
101
+ # Add logo file path and hyperlinks
102
+ LOGO_PATH = "app_assets/logo_temp_.png" # Update this to the actual path of your logo
103
+ ARXIV_LINK = "https://arxiv.org/abs/example"
104
+ GITHUB_LINK = "https://github.com/example"
105
+
106
+ k3d_wrapper = init_wrapper_from_config('./pipeline/pipeline_config/default.yaml')
107
+
108
+
109
+ TEMP_MESH_ADDRESS=''
110
+
111
+ mesh_cache = None
112
+ preprocessed_input_image = None
113
+
114
+ def save_cached_mesh():
115
+ global mesh_cache
116
+ return mesh_cache
117
+ # if mesh_cache is None:
118
+ # return None
119
+ # return save_py3dmesh_with_trimesh_fast(mesh_cache)
120
+
121
+ def save_py3dmesh_with_trimesh_fast(meshes, save_glb_path=TEMP_MESH_ADDRESS, apply_sRGB_to_LinearRGB=True):
122
+ from pytorch3d.structures import Meshes
123
+ import trimesh
124
+
125
+ # convert from pytorch3d meshes to trimesh mesh
126
+ vertices = meshes.verts_packed().cpu().float().numpy()
127
+ triangles = meshes.faces_packed().cpu().long().numpy()
128
+ np_color = meshes.textures.verts_features_packed().cpu().float().numpy()
129
+ if save_glb_path.endswith(".glb"):
130
+ # rotate 180 along +Y
131
+ vertices[:, [0, 2]] = -vertices[:, [0, 2]]
132
+
133
+ def srgb_to_linear(c_srgb):
134
+ c_linear = np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4)
135
+ return c_linear.clip(0, 1.)
136
+ if apply_sRGB_to_LinearRGB:
137
+ np_color = srgb_to_linear(np_color)
138
+ assert vertices.shape[0] == np_color.shape[0]
139
+ assert np_color.shape[1] == 3
140
+ assert 0 <= np_color.min() and np_color.max() <= 1, f"min={np_color.min()}, max={np_color.max()}"
141
+ mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color)
142
+ mesh.remove_unreferenced_vertices()
143
+ # save mesh
144
+ mesh.export(save_glb_path)
145
+ if save_glb_path.endswith(".glb"):
146
+ fix_vert_color_glb(save_glb_path)
147
+ print(f"saving to {save_glb_path}")
148
+ #
149
+ #
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  @spaces.GPU
152
+ def text_to_detailed(prompt, seed=None):
153
+ print(f"Before text_to_detailed: {torch.cuda.memory_allocated() / 1024**3} GB")
154
+ return k3d_wrapper.get_detailed_prompt(prompt, seed)
155
+
156
+ @spaces.GPU
157
+ def text_to_image(prompt, seed=None, strength=1.0,lora_scale=1.0, num_inference_steps=30, redux_hparam=None, init_image=None, **kwargs):
158
+ print(f"Before text_to_image: {torch.cuda.memory_allocated() / 1024**3} GB")
159
+ k3d_wrapper.renew_uuid()
160
+ init_image = None
161
+ if init_image_path is not None:
162
+ init_image = Image.open(init_image_path)
163
+ result = k3d_wrapper.generate_3d_bundle_image_text(
164
+ prompt,
165
+ image=init_image,
166
+ strength=strength,
167
+ lora_scale=lora_scale,
168
+ num_inference_steps=num_inference_steps,
169
+ seed=int(seed) if seed is not None else None,
170
+ redux_hparam=redux_hparam,
171
+ save_intermediate_results=True,
172
+ **kwargs)
173
+ return result[-1]
174
+
175
+ def image2mesh_preprocess_(input_image_, seed, use_mv_rgb=True):
176
+ global preprocessed_input_image
177
+
178
+ seed = int(seed) if seed is not None else None
179
+
180
+ # TODO: delete this later
181
+ k3d_wrapper.del_llm_model()
 
 
182
 
183
+ input_image_save_path, reference_save_path, caption = image2mesh_preprocess(k3d_wrapper, input_image_, seed, use_mv_rgb)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
+ preprocessed_input_image = Image.open(input_image_save_path)
186
+ return reference_save_path, caption
187
 
188
+ @spaces.GPU
189
+ def image2mesh_main_(reference_3d_bundle_image, caption, seed, strength1=0.5, strength2=0.95, enable_redux=True, use_controlnet=True, if_video=True):
190
+ global mesh_cache
191
+ seed = int(seed) if seed is not None else None
192
+
193
+
194
+ # TODO: delete this later
195
+ k3d_wrapper.del_llm_model()
196
+
197
+ input_image = preprocessed_input_image
 
198
 
199
+ reference_3d_bundle_image = torch.tensor(reference_3d_bundle_image).permute(2,0,1)/255
200
+
201
+ gen_save_path, recon_mesh_path = image2mesh_main(k3d_wrapper, input_image, reference_3d_bundle_image, caption=caption, seed=seed, strength1=strength1, strength2=strength2, enable_redux=enable_redux, use_controlnet=use_controlnet)
202
+ mesh_cache = recon_mesh_path
203
+
204
+
205
+ # gen_save_ = Image.open(gen_save_path)
206
+
207
+ if if_video:
208
+ video_path = recon_mesh_path.replace('.obj','.mp4').replace('.glb','.mp4')
209
+ render_video_from_obj(recon_mesh_path, video_path)
210
+ print(f"After bundle_image_to_mesh: {torch.cuda.memory_allocated() / 1024**3} GB")
211
+ return gen_save_path, video_path
212
+ else:
213
+ return gen_save_path, recon_mesh_path
214
+ # return gen_save_path, recon_mesh_path
215
 
 
216
  @spaces.GPU
217
+ def bundle_image_to_mesh(
218
+ gen_3d_bundle_image,
219
+ lrm_radius = 4.15,
220
+ isomer_radius = 4.5,
221
+ reconstruction_stage1_steps = 10,
222
+ reconstruction_stage2_steps = 50,
223
+ save_intermediate_results=True,
224
+ if_video=True
225
+ ):
226
+ global mesh_cache
227
+ print(f"Before bundle_image_to_mesh: {torch.cuda.memory_allocated() / 1024**3} GB")
228
+
229
+ # TODO: delete this later
230
+ k3d_wrapper.del_llm_model()
231
+
232
+ print(f"Before bundle_image_to_mesh after deleting llm model: {torch.cuda.memory_allocated() / 1024**3} GB")
233
+
234
+ gen_3d_bundle_image = torch.tensor(gen_3d_bundle_image).permute(2,0,1)/255
235
+ # recon from 3D Bundle image
236
+ recon_mesh_path = k3d_wrapper.reconstruct_3d_bundle_image(gen_3d_bundle_image, lrm_render_radius=lrm_radius, isomer_radius=isomer_radius, save_intermediate_results=save_intermediate_results, reconstruction_stage1_steps=int(reconstruction_stage1_steps), reconstruction_stage2_steps=int(reconstruction_stage2_steps))
237
+ mesh_cache = recon_mesh_path
238
+
239
+ if if_video:
240
+ video_path = recon_mesh_path.replace('.obj','.mp4').replace('.glb','.mp4')
241
+ # # 检查这个video_path文件大小是是否超过50KB,不超过的话就认为是空文件,需要重新渲染
242
+ # if os.path.exists(video_path):
243
+ # print(f"file size:{os.path.getsize(video_path)}")
244
+ # if os.path.getsize(video_path) > 50*1024:
245
+ # print(f"video path:{video_path}")
246
+ # return video_path
247
+ render_video_from_obj(recon_mesh_path, video_path)
248
+ print(f"After bundle_image_to_mesh: {torch.cuda.memory_allocated() / 1024**3} GB")
249
+ return video_path
250
+ else:
251
+ return recon_mesh_path
252
+
253
+ _HEADER_=f"""
254
+ <img src="{LOGO_PATH}">
255
+ <h2><b>Official 🤗 Gradio Demo</b></h2><h2>
256
+ <b>Kiss3DGen: Repurposing Image Diffusion Models for 3D Asset Generation</b></a></h2>
257
+
258
+ <p>**Kiss3DGen** is xxxxxxxxx</p>
259
+
260
+ [![arXiv](https://img.shields.io/badge/arXiv-Link-red)]({ARXIV_LINK}) [![GitHub](https://img.shields.io/badge/GitHub-Repo-blue)]({GITHUB_LINK})
261
+ """
262
+
263
+ _CITE_ = r"""
264
+ <h2>If Kiss3DGen is helpful, please help to ⭐ the <a href='{""" + GITHUB_LINK + r"""}' target='_blank'>Github Repo</a>. Thanks!</h2>
265
+
266
+ 📝 **Citation**
267
+
268
+ If you find our work useful for your research or applications, please cite using this bibtex:
269
+ ```bibtex
270
+ @article{xxxx,
271
+ title={xxxx},
272
+ author={xxxx},
273
+ journal={xxxx},
274
+ year={xxxx}
275
+ }
276
+ ```
277
+
278
+ 📋 **License**
279
+
280
+ Apache-2.0 LICENSE. Please refer to the [LICENSE file](https://huggingface.co/spaces/TencentARC/InstantMesh/blob/main/LICENSE) for details.
281
+
282
+ 📧 **Contact**
283
+
284
+ If you have any questions, feel free to open a discussion or contact us at <b>xxx@xxxx</b>.
285
+ """
286
+
287
+ def image_to_base64(image_path):
288
+ """Converts an image file to a base64-encoded string."""
289
+ with open(image_path, "rb") as img_file:
290
+ return base64.b64encode(img_file.read()).decode('utf-8')
291
+
292
+ def main():
293
+
294
+ torch.set_grad_enabled(False)
295
+
296
+ # Convert the logo image to base64
297
+ logo_base64 = image_to_base64(LOGO_PATH)
298
+ # with gr.Blocks() as demo:
299
+ with gr.Blocks(css="""
300
+ body {
301
+ display: flex;
302
+ justify-content: center;
303
+ align-items: center;
304
+ min-height: 100vh;
305
+ margin: 0;
306
+ padding: 0;
307
+ }
308
+ #col-container { margin: 0px auto; max-width: 200px; }
309
+
310
+
311
+ .gradio-container {
312
+ max-width: 1000px;
313
+ margin: auto;
314
+ width: 100%;
315
+ }
316
+ #center-align-column {
317
+ display: flex;
318
+ justify-content: center;
319
+ align-items: center;
320
+ }
321
+ #right-align-column {
322
+ display: flex;
323
+ justify-content: flex-end;
324
+ align-items: center;
325
+ }
326
+ h1 {text-align: center;}
327
+ h2 {text-align: center;}
328
+ h3 {text-align: center;}
329
+ p {text-align: center;}
330
+ img {text-align: right;}
331
+ .right {
332
+ display: block;
333
+ margin-left: auto;
334
+ }
335
+ .center {
336
+ display: block;
337
+ margin-left: auto;
338
+ margin-right: auto;
339
+ width: 50%;
340
+
341
+ #content-container {
342
+ max-width: 1200px;
343
+ margin: 0 auto;
344
+ }
345
+ #example-container {
346
+ max-width: 300px;
347
+ margin: 0 auto;
348
+ }
349
+ """,elem_id="col-container") as demo:
350
+ # Header Section
351
+ # gr.Image(value=LOGO_PATH, width=64, height=64)
352
+ # gr.Markdown(_HEADER_)
353
+ with gr.Row(elem_id="content-container"):
354
+ # with gr.Column(scale=1):
355
+ # pass
356
+ # with gr.Column(scale=1, elem_id="right-align-column"):
357
+ # # gr.Image(value=LOGO_PATH, interactive=False, show_label=False, width=64, height=64, elem_id="logo-image")
358
+ # # gr.Markdown(f"<img src='{LOGO_PATH}' alt='Logo' style='width:64px;height:64px;border:0;'>")
359
+ # # gr.HTML(f"<img src='data:image/png;base64,{logo_base64}' alt='Logo' class='right' style='width:64px;height:64px;border:0;text-align:right;'>")
360
+ # pass
361
+ with gr.Column(scale=7, elem_id="center-align-column"):
362
+ gr.Markdown(f"""
363
+ ## Official 🤗 Gradio Demo
364
+ # Kiss3DGen: Repurposing Image Diffusion Models for 3D Asset Generation""")
365
+ gr.HTML(f"<img src='data:image/png;base64,{logo_base64}' alt='Logo' class='center' style='width:64px;height:64px;border:0;text-align:center;'>")
366
+
367
+ gr.HTML(f"""
368
+ <div style="display: flex; justify-content: center; align-items: center; gap: 10px;">
369
+ <a href="{ARXIV_LINK}" target="_blank">
370
+ <img src="https://img.shields.io/badge/arXiv-Link-red" alt="arXiv">
371
+ </a>
372
+ <a href="{GITHUB_LINK}" target="_blank">
373
+ <img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub">
374
+ </a>
375
+ </div>
376
+
377
+ """)
378
+
379
+
380
+ # gr.HTML(f"""
381
+ # <div style="display: flex; gap: 10px; align-items: center;"><a href="{ARXIV_LINK}" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/arXiv-Link-red" alt="arXiv"></a> <a href="{GITHUB_LINK}" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub"></a></div>
382
+ # """)
383
+
384
+ # gr.Markdown(f"""
385
+ # [![arXiv](https://img.shields.io/badge/arXiv-Link-red)]({ARXIV_LINK}) [![GitHub](https://img.shields.io/badge/GitHub-Repo-blue)]({GITHUB_LINK})
386
+ # """, elem_id="title")
387
+ # with gr.Column(scale=1):
388
+ # pass
389
+ # with gr.Row():
390
+ # gr.Markdown(f"[![arXiv](https://img.shields.io/badge/arXiv-Link-red)]({ARXIV_LINK})")
391
+ # gr.Markdown(f"[![GitHub](https://img.shields.io/badge/GitHub-Repo-blue)]({GITHUB_LINK})")
392
+
393
+ # Tabs Section
394
+ with gr.Tabs(selected='tab_text_to_3d', elem_id="content-container") as main_tabs:
395
+ with gr.TabItem('Text-to-3D', id='tab_text_to_3d'):
396
+ with gr.Row():
397
+ with gr.Column(scale=1):
398
+ prompt = gr.Textbox(value="", label="Input Prompt", lines=4)
399
+ seed1 = gr.Number(value=10, label="Seed")
400
+
401
+ with gr.Row(elem_id="example-container"):
402
+ gr.Examples(
403
+ examples=[
404
+ # ["A tree with red leaves"],
405
+ # ["A dragon with black texture"],
406
+ ["A girl with pink hair"],
407
+ ["A boy playing guitar"],
408
+
409
+
410
+ ["A dog wearing a hat"],
411
+ ["A boy playing basketball"],
412
+ # [""],
413
+ # [""],
414
+ # [""],
415
+
416
+ ],
417
+ inputs=[prompt], # 将选中的示例填入 prompt 文本框
418
+ label="Example Prompts"
419
+ )
420
+ btn_text2detailed = gr.Button("Refine to detailed prompt")
421
+ detailed_prompt = gr.Textbox(value="", label="Detailed Prompt", placeholder="detailed prompt will be generated here base on your input prompt. You can also edit this prompt", lines=4, interactive=True)
422
+ btn_text2img = gr.Button("Generate Images")
423
+
424
+ with gr.Column(scale=1):
425
+ output_image1 = gr.Image(label="Generated image", interactive=False)
426
+
427
+
428
+ # lrm_radius = gr.Number(value=4.15, label="lrm_radius")
429
+ # isomer_radius = gr.Number(value=4.5, label="isomer_radius")
430
+ # reconstruction_stage1_steps = gr.Number(value=10, label="reconstruction_stage1_steps")
431
+ # reconstruction_stage2_steps = gr.Number(value=50, label="reconstruction_stage2_steps")
432
+
433
+ btn_gen_mesh = gr.Button("Generate Mesh")
434
+ output_video1 = gr.Video(label="Generated Video", interactive=False, loop=True, autoplay=True)
435
+ btn_download1 = gr.Button("Download Mesh")
436
+
437
+ file_output1 = gr.File()
438
+
439
+ with gr.TabItem('Image-to-3D', id='tab_image_to_3d'):
440
+ with gr.Row():
441
+ with gr.Column(scale=1):
442
+ image = gr.Image(label="Input Image", type="pil")
443
+
444
+ seed2 = gr.Number(value=10, label="Seed (0 for random)")
445
+
446
+ btn_img2mesh_preprocess = gr.Button("Preprocess Image")
447
+
448
+ image_caption = gr.Textbox(value="", label="Image Caption", placeholder="caption will be generated here base on your input image. You can also edit this caption", lines=4, interactive=True)
449
+
450
+ output_image2 = gr.Image(label="Generated image", interactive=False)
451
+ strength1 = gr.Slider(minimum=0, maximum=1.0, step=0.01, value=0.5, label="strength1")
452
+ strength2 = gr.Slider(minimum=0, maximum=1.0, step=0.01, value=0.95, label="strength2")
453
+ enable_redux = gr.Checkbox(label="enable redux", value=True)
454
+ use_controlnet = gr.Checkbox(label="use controlnet", value=True)
455
+
456
+ btn_img2mesh_main = gr.Button("Generate Mesh")
457
+
458
+ with gr.Column(scale=1):
459
+
460
+ # output_mesh2 = gr.Model3D(label="Generated Mesh", interactive=False)
461
+ output_image3 = gr.Image(label="gen save image", interactive=False)
462
+ output_video2 = gr.Video(label="Generated Video", interactive=False, loop=True, autoplay=True)
463
+ btn_download2 = gr.Button("Download Mesh")
464
+ file_output2 = gr.File()
465
+
466
+ # Image2
467
+ btn_img2mesh_preprocess.click(fn=image2mesh_preprocess_, inputs=[image, seed2], outputs=[output_image2, image_caption])
468
+
469
+ btn_img2mesh_main.click(fn=image2mesh_main_, inputs=[output_image2, image_caption, seed2, strength1, strength2, enable_redux, use_controlnet], outputs=[output_image3, output_video2])
470
+
471
+
472
+ btn_download2.click(fn=save_cached_mesh, inputs=[], outputs=file_output2)
473
+
474
+
475
+ # Button Click Events
476
+ # Text2
477
+ btn_text2detailed.click(fn=text_to_detailed, inputs=[prompt, seed1], outputs=detailed_prompt)
478
+ btn_text2img.click(fn=text_to_image, inputs=[detailed_prompt, seed1], outputs=output_image1)
479
+ btn_gen_mesh.click(fn=bundle_image_to_mesh, inputs=[output_image1,], outputs=output_video1)
480
+ # btn_gen_mesh.click(fn=bundle_image_to_mesh, inputs=[output_image1, lrm_radius, isomer_radius, reconstruction_stage1_steps, reconstruction_stage2_steps], outputs=output_video1)
481
+
482
+ with gr.Row():
483
+ pass
484
+ with gr.Row():
485
+ gr.Markdown(_CITE_)
486
+
487
+ # demo.queue(default_concurrency_limit=1)
488
+ # demo.launch(server_name="0.0.0.0", server_port=9239)
489
+ demo.launch()
490
+
491
 
492
+ if __name__ == "__main__":
493
+ main()
 
demo.py → app_demo.py RENAMED
@@ -4,11 +4,10 @@ import subprocess
4
  import shlex
5
  import spaces
6
  import torch
7
- import numpy as numpy
8
  access_token = os.getenv("HUGGINGFACE_TOKEN")
9
  subprocess.run(
10
  shlex.split(
11
- "pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt210/download.html"
12
  )
13
  )
14
 
@@ -20,7 +19,7 @@ subprocess.run(
20
 
21
  subprocess.run(
22
  shlex.split(
23
- "pip install ./extension/renderutils_plugin-1.0-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
24
  )
25
  )
26
  def install_cuda_toolkit():
@@ -41,7 +40,7 @@ def install_cuda_toolkit():
41
  # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
42
  os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
43
  print("==> finfish install")
44
- # install_cuda_toolkit()
45
  @spaces.GPU
46
  def check_gpu():
47
  os.environ['CUDA_HOME'] = '/usr/local/cuda-12.1'
@@ -84,8 +83,8 @@ from huggingface_hub import hf_hub_download
84
 
85
  from utils.tool import NormalTransfer, get_background, get_render_cameras_video, load_mipmap, render_frames
86
 
87
- device_0 = "cuda:0"
88
- device_1 = "cuda:1"
89
  resolution = 512
90
  save_dir = "./outputs"
91
  normal_transfer = NormalTransfer()
@@ -97,15 +96,15 @@ isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(d
97
 
98
  # model initialization and loading
99
  # flux
100
- taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to(device_0)
101
- good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16, token=access_token).to(device_0)
102
  # flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, token=access_token).to(device=device_0, dtype=torch.bfloat16)
103
- flux_pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, vae=taef1, token=access_token).to(device_0)
104
- flux_lora_ckpt_path = hf_hub_download(repo_id="LTT/xxx-ckpt", filename="rgb_normal_large.safetensors", repo_type="model")
105
- flux_pipe.load_lora_weights(flux_lora_ckpt_path)
106
  # flux_pipe.to(device=device_0, dtype=torch.bfloat16)
107
- torch.cuda.empty_cache()
108
- flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(flux_pipe)
109
 
110
 
111
  # lrm
@@ -159,7 +158,7 @@ def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", expor
159
 
160
  all_mv, all_mvp, all_campos = get_render_cameras_video(
161
  batch_size=1,
162
- M=240,
163
  radius=4.5,
164
  elevation=(90, 60.0),
165
  is_flexicubes=True,
@@ -209,28 +208,27 @@ def generate_multi_view_images(prompt, seed):
209
  # generator = torch.manual_seed(seed)
210
  generator = torch.Generator().manual_seed(seed)
211
  with torch.no_grad():
212
- # images = flux_pipe(
 
 
 
 
 
 
 
 
 
 
213
  # prompt=prompt,
214
- # num_inference_steps=10,
215
  # guidance_scale=3.5,
216
- # num_images_per_prompt=1,
217
  # width=resolution * 4,
218
  # height=resolution * 2,
219
- # output_type='np',
220
  # generator=generator,
 
221
  # good_vae=good_vae,
222
- # ).images
223
- for img in flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images(
224
- prompt=prompt,
225
- guidance_scale=3.5,
226
- num_inference_steps=10,
227
- width=resolution * 4,
228
- height=resolution * 2,
229
- generator=generator,
230
- output_type="np",
231
- good_vae=good_vae,
232
- ):
233
- pass
234
  # 返回最终的图像和种子(通过外部调用处理)
235
  return img
236
 
@@ -251,7 +249,7 @@ def reconstruct_3d_model(images, prompt):
251
  multi_view_mask = get_background(normal_multi_view)
252
  rgb_multi_view = rgb_multi_view * rgb_multi_view + (1-multi_view_mask)
253
  input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device_1)
254
- vertices, faces = lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', export_texmap=False, if_save_video=False)
255
  # local normal to global normal
256
 
257
  global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1), isomer_azimuths, isomer_elevations)
@@ -307,19 +305,81 @@ def reconstruct_3d_model(images, prompt):
307
  # Gradio 接口函数
308
  @spaces.GPU
309
  def gradio_pipeline(prompt, seed):
 
 
 
 
 
 
 
 
310
  # 生成多视图图像
311
- rgb_normal_grid = generate_multi_view_images(prompt, seed)
312
- image_preview = Image.fromarray((rgb_normal_grid * 255).astype(np.uint8))
 
313
 
314
  # 3d reconstruction
315
 
316
 
317
  # 重建 3D 模型并返回 glb 路径
318
  save_glb_addr = reconstruct_3d_model(rgb_normal_grid, prompt)
319
-
320
  return image_preview, save_glb_addr
321
 
322
- if __name__ == "__main__":
323
- prompt_input = "a owm"
324
- sample_seed = 42
325
- gradio_pipeline(prompt_input, sample_seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import shlex
5
  import spaces
6
  import torch
 
7
  access_token = os.getenv("HUGGINGFACE_TOKEN")
8
  subprocess.run(
9
  shlex.split(
10
+ "pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt240/download.html"
11
  )
12
  )
13
 
 
19
 
20
  subprocess.run(
21
  shlex.split(
22
+ "pip install ./extension/renderutils_plugin-0.1.0-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
23
  )
24
  )
25
  def install_cuda_toolkit():
 
40
  # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
41
  os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
42
  print("==> finfish install")
43
+ install_cuda_toolkit()
44
  @spaces.GPU
45
  def check_gpu():
46
  os.environ['CUDA_HOME'] = '/usr/local/cuda-12.1'
 
83
 
84
  from utils.tool import NormalTransfer, get_background, get_render_cameras_video, load_mipmap, render_frames
85
 
86
+ device_0 = "cuda"
87
+ device_1 = "cuda"
88
  resolution = 512
89
  save_dir = "./outputs"
90
  normal_transfer = NormalTransfer()
 
96
 
97
  # model initialization and loading
98
  # flux
99
+ # # taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to(device_0)
100
+ # # good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16, token=access_token).to(device_0)
101
  # flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, token=access_token).to(device=device_0, dtype=torch.bfloat16)
102
+ # # flux_pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, vae=taef1, token=access_token).to(device_0)
103
+ # flux_lora_ckpt_path = hf_hub_download(repo_id="LTT/xxx-ckpt", filename="rgb_normal_large.safetensors", repo_type="model", token=access_token)
104
+ # flux_pipe.load_lora_weights(flux_lora_ckpt_path)
105
  # flux_pipe.to(device=device_0, dtype=torch.bfloat16)
106
+ # torch.cuda.empty_cache()
107
+ # flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(flux_pipe)
108
 
109
 
110
  # lrm
 
158
 
159
  all_mv, all_mvp, all_campos = get_render_cameras_video(
160
  batch_size=1,
161
+ M=24,
162
  radius=4.5,
163
  elevation=(90, 60.0),
164
  is_flexicubes=True,
 
208
  # generator = torch.manual_seed(seed)
209
  generator = torch.Generator().manual_seed(seed)
210
  with torch.no_grad():
211
+ img = flux_pipe(
212
+ prompt=prompt,
213
+ num_inference_steps=5,
214
+ guidance_scale=3.5,
215
+ num_images_per_prompt=1,
216
+ width=resolution * 2,
217
+ height=resolution * 1,
218
+ output_type='np',
219
+ generator=generator,
220
+ ).images
221
+ # for img in flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images(
222
  # prompt=prompt,
 
223
  # guidance_scale=3.5,
224
+ # num_inference_steps=4,
225
  # width=resolution * 4,
226
  # height=resolution * 2,
 
227
  # generator=generator,
228
+ # output_type="np",
229
  # good_vae=good_vae,
230
+ # ):
231
+ # pass
 
 
 
 
 
 
 
 
 
 
232
  # 返回最终的图像和种子(通过外部调用处理)
233
  return img
234
 
 
249
  multi_view_mask = get_background(normal_multi_view)
250
  rgb_multi_view = rgb_multi_view * rgb_multi_view + (1-multi_view_mask)
251
  input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device_1)
252
+ vertices, faces = lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', export_texmap=False, if_save_video=True)
253
  # local normal to global normal
254
 
255
  global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1), isomer_azimuths, isomer_elevations)
 
305
  # Gradio 接口函数
306
  @spaces.GPU
307
  def gradio_pipeline(prompt, seed):
308
+ import ctypes
309
+ # 显式加载 libnvrtc.so.12
310
+ cuda_lib_path = "/usr/local/cuda-12.1/lib64/libnvrtc.so.12"
311
+ try:
312
+ ctypes.CDLL(cuda_lib_path, mode=ctypes.RTLD_GLOBAL)
313
+ print(f"Successfully preloaded {cuda_lib_path}")
314
+ except OSError as e:
315
+ print(f"Failed to preload {cuda_lib_path}: {e}")
316
  # 生成多视图图像
317
+ # rgb_normal_grid = generate_multi_view_images(prompt, seed)
318
+ rgb_normal_grid = np.load("rgb_normal_grid.npy")
319
+ image_preview = Image.fromarray((rgb_normal_grid[0] * 255).astype(np.uint8))
320
 
321
  # 3d reconstruction
322
 
323
 
324
  # 重建 3D 模型并返回 glb 路径
325
  save_glb_addr = reconstruct_3d_model(rgb_normal_grid, prompt)
326
+ # save_glb_addr = None
327
  return image_preview, save_glb_addr
328
 
329
+ # Gradio Blocks 应用
330
+ with gr.Blocks() as demo:
331
+ with gr.Row(variant="panel"):
332
+ # 左侧输入区域
333
+ with gr.Column():
334
+ with gr.Row():
335
+ prompt_input = gr.Textbox(
336
+ label="Enter Prompt",
337
+ placeholder="Describe your 3D model...",
338
+ lines=2,
339
+ elem_id="prompt_input"
340
+ )
341
+
342
+ with gr.Row():
343
+ sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
344
+
345
+ with gr.Row():
346
+ submit = gr.Button("Generate", elem_id="generate", variant="primary")
347
+
348
+ with gr.Row(variant="panel"):
349
+ gr.Markdown("Examples:")
350
+ gr.Examples(
351
+ examples=[
352
+ ["a castle on a hill"],
353
+ ["an owl wearing a hat"],
354
+ ["a futuristic car"]
355
+ ],
356
+ inputs=[prompt_input],
357
+ label="Prompt Examples"
358
+ )
359
+
360
+ # 右侧输出区域
361
+ with gr.Column():
362
+ with gr.Row():
363
+ rgb_normal_grid_image = gr.Image(
364
+ label="RGB Normal Grid",
365
+ type="pil",
366
+ interactive=False
367
+ )
368
+
369
+ with gr.Row():
370
+ with gr.Tab("GLB"):
371
+ output_glb_model = gr.Model3D(
372
+ label="Generated 3D Model (GLB Format)",
373
+ interactive=False
374
+ )
375
+ gr.Markdown("Download the model for proper visualization.")
376
+
377
+ # 处理逻辑
378
+ submit.click(
379
+ fn=gradio_pipeline, inputs=[prompt_input, sample_seed],
380
+ outputs=[rgb_normal_grid_image, output_glb_model]
381
+ )
382
+
383
+ # 启动应用
384
+ # demo.queue(max_size=10)
385
+ demo.launch()
app_flux.py DELETED
@@ -1,141 +0,0 @@
1
- import gradio as gr
2
- import numpy as np
3
- import os
4
- import random
5
- import spaces
6
- import torch
7
- from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
8
- from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
9
- from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
10
-
11
- dtype = torch.bfloat16
12
- device = "cuda" if torch.cuda.is_available() else "cpu"
13
- access_token = os.getenv("HUGGINGFACE_TOKEN")
14
- taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
15
- good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype, token=access_token).to(device)
16
- pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=taef1, token=access_token).to(device)
17
- torch.cuda.empty_cache()
18
-
19
- MAX_SEED = np.iinfo(np.int32).max
20
- MAX_IMAGE_SIZE = 2048
21
-
22
- pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
23
-
24
- @spaces.GPU(duration=75)
25
- def infer(prompt, seed=42, randomize_seed=False, width=2048, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
26
- if randomize_seed:
27
- seed = random.randint(0, MAX_SEED)
28
- generator = torch.Generator().manual_seed(seed)
29
-
30
- for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
31
- prompt=prompt,
32
- guidance_scale=guidance_scale,
33
- num_inference_steps=num_inference_steps,
34
- width=width,
35
- height=height,
36
- generator=generator,
37
- output_type="pil",
38
- good_vae=good_vae,
39
- ):
40
- # yield img, seed
41
- pass
42
- return img, seed
43
- examples = [
44
- "a tiny astronaut hatching from an egg on the moon",
45
- "a cat holding a sign that says hello world",
46
- "an anime illustration of a wiener schnitzel",
47
- ]
48
-
49
- css="""
50
- #col-container {
51
- margin: 0 auto;
52
- max-width: 520px;
53
- }
54
- """
55
-
56
- with gr.Blocks(css=css) as demo:
57
-
58
- with gr.Column(elem_id="col-container"):
59
- gr.Markdown(f"""# FLUX.1 [dev]
60
- 12B param rectified flow transformer guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/)
61
- [[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-dev)]
62
- """)
63
-
64
- with gr.Row():
65
-
66
- prompt = gr.Text(
67
- label="Prompt",
68
- show_label=False,
69
- max_lines=1,
70
- placeholder="Enter your prompt",
71
- container=False,
72
- )
73
-
74
- run_button = gr.Button("Run", scale=0)
75
-
76
- result = gr.Image(label="Result", show_label=False)
77
-
78
- with gr.Accordion("Advanced Settings", open=False):
79
-
80
- seed = gr.Slider(
81
- label="Seed",
82
- minimum=0,
83
- maximum=MAX_SEED,
84
- step=1,
85
- value=0,
86
- )
87
-
88
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
89
-
90
- with gr.Row():
91
-
92
- width = gr.Slider(
93
- label="Width",
94
- minimum=256,
95
- maximum=MAX_IMAGE_SIZE,
96
- step=32,
97
- value=1024,
98
- )
99
-
100
- height = gr.Slider(
101
- label="Height",
102
- minimum=256,
103
- maximum=MAX_IMAGE_SIZE,
104
- step=32,
105
- value=1024,
106
- )
107
-
108
- with gr.Row():
109
-
110
- guidance_scale = gr.Slider(
111
- label="Guidance Scale",
112
- minimum=1,
113
- maximum=15,
114
- step=0.1,
115
- value=3.5,
116
- )
117
-
118
- num_inference_steps = gr.Slider(
119
- label="Number of inference steps",
120
- minimum=1,
121
- maximum=50,
122
- step=1,
123
- value=28,
124
- )
125
-
126
- gr.Examples(
127
- examples = examples,
128
- fn = infer,
129
- inputs = [prompt],
130
- outputs = [result, seed],
131
- cache_examples="lazy"
132
- )
133
-
134
- gr.on(
135
- triggers=[run_button.click, prompt.submit],
136
- fn = infer,
137
- inputs = [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
138
- outputs = [result, seed]
139
- )
140
-
141
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
image_to_mesh.py DELETED
@@ -1,437 +0,0 @@
1
- import os
2
- from einops import rearrange
3
- from omegaconf import OmegaConf
4
- import torch
5
- import numpy as np
6
- import trimesh
7
- import torchvision
8
- import torch.nn.functional as F
9
- from PIL import Image
10
- from torchvision import transforms
11
- from torchvision.transforms import v2
12
- from transformers import AutoProcessor, AutoModelForCausalLM
13
- import rembg
14
- from diffusers import FluxPipeline, FluxControlNetImg2ImgPipeline
15
- from diffusers.models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
16
- from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, HeunDiscreteScheduler
17
- from pytorch_lightning import seed_everything
18
- import os
19
-
20
- from models.ISOMER.reconstruction_func import reconstruction
21
- from models.ISOMER.projection_func import projection
22
- from models.lrm.utils.infer_util import remove_background, resize_foreground, save_video
23
- from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
24
- from models.lrm.utils.render_utils import rotate_x, rotate_y
25
- from models.lrm.utils.train_util import instantiate_from_config
26
- from models.lrm.utils.camera_util import get_zero123plus_input_cameras, get_custom_zero123plus_input_cameras, get_flux_input_cameras
27
- from utils.tool import NormalTransfer, get_render_cameras_frames, load_mipmap
28
- from utils.tool import get_background, get_render_cameras_video, render_frames
29
- import time
30
-
31
- device = "cuda"
32
- resolution = 512
33
- save_dir = "./outputs"
34
- zero123plus_diffusion_steps = 75
35
- normal_transfer = NormalTransfer()
36
- rembg_session = rembg.new_session()
37
- isomer_azimuths = torch.from_numpy(np.array([270, 0, 90, 180])).to(device)
38
- isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).to(device)
39
- isomer_radius = 4.1
40
- isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device)
41
- isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device)
42
- # seed_everything(42)
43
-
44
- # model initialization and loading
45
- # flux
46
- print('==> Loading Flux model ...')
47
- flux_base_model_pth = "/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/models--black-forest-labs--FLUX.1-dev"
48
- flux_controlnet = FluxControlNetModel.from_pretrained("/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/flux_controlnets/FLUX.1-dev-ControlNet-Union-Pro")
49
- flux_pipe = FluxControlNetImg2ImgPipeline.from_pretrained(flux_base_model_pth, controlnet=[flux_controlnet], torch_dtype=torch.bfloat16).to(device=device, dtype=torch.bfloat16)
50
-
51
- flux_pipe.load_lora_weights('./checkpoint/flux_lora/rgb_normal_large.safetensors')
52
-
53
-
54
- flux_pipe.to(device=device, dtype=torch.bfloat16)
55
- generator = torch.Generator(device=device).manual_seed(0)
56
-
57
- # lrm
58
- print('==> Loading LRM model ...')
59
- config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
60
- model_config = config.model_config
61
- infer_config = config.infer_config
62
- model = instantiate_from_config(model_config)
63
- model_ckpt_path = "./checkpoint/lrm/final_ckpt.ckpt"
64
- state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
65
- state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
66
- model.load_state_dict(state_dict, strict=True)
67
-
68
- model = model.to(device)
69
- model.init_flexicubes_geometry(device, fovy=50.0)
70
- model = model.eval()
71
-
72
- # zero123++
73
- print('==> Loading diffusion model ...')
74
- zero123plus_pipeline = DiffusionPipeline.from_pretrained(
75
- "sudo-ai/zero123plus-v1.2",
76
- custom_pipeline="./models/zero123plus",
77
- torch_dtype=torch.float16,
78
- )
79
- zero123plus_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
80
- zero123plus_pipeline.scheduler.config, timestep_spacing='trailing'
81
- )
82
- unet_ckpt_path = "./checkpoint/zero123++/flexgen_19w.ckpt"
83
- state_dict = torch.load(unet_ckpt_path, map_location='cpu')['state_dict']
84
- state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')}
85
- zero123plus_pipeline.unet.load_state_dict(state_dict, strict=True)
86
- zero123plus_pipeline = zero123plus_pipeline.to(device)
87
-
88
- # unet_ckpt_path = "checkpoint/zero123++/diffusion_pytorch_model.bin"
89
- # state_dict = torch.load(unet_ckpt_path, map_location='cpu')
90
- # zero123plus_pipeline.unet.load_state_dict(state_dict, strict=True)
91
- # zero123plus_pipeline = zero123plus_pipeline.to(device)
92
-
93
- # florence
94
- caption_model = AutoModelForCausalLM.from_pretrained(
95
- "/hpc2hdd/home/jlin695/.cache/huggingface/hub/models--multimodalart--Florence-2-large-no-flash-attn/snapshots/8db3793cf5b453b2ccfb3a4f613b403b2e6b7ca2", torch_dtype=torch.bfloat16, trust_remote_code=True,
96
- ).to(device)
97
- caption_processor = AutoProcessor.from_pretrained("/hpc2hdd/home/jlin695/.cache/huggingface/hub/models--multimodalart--Florence-2-large-no-flash-attn/snapshots/8db3793cf5b453b2ccfb3a4f613b403b2e6b7ca2", trust_remote_code=True)
98
-
99
- # Flux multi-view generation
100
- def multi_view_rgb_normal_generation_with_controlnet(prompt, image, strength=1.0,
101
- control_image=[],
102
- control_mode=[],
103
- control_guidance_start=None,
104
- control_guidance_end=None,
105
- controlnet_conditioning_scale=None,
106
- lora_scale=1.0
107
- ):
108
- control_mode_dict = {
109
- 'canny': 0,
110
- 'tile': 1,
111
- 'depth': 2,
112
- 'blur': 3,
113
- 'pose': 4,
114
- 'gray': 5,
115
- 'lq': 6,
116
- } # for https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union only
117
-
118
- hparam_dict = {
119
- 'prompt': prompt,
120
- 'image': image,
121
- 'strength': strength,
122
- 'num_inference_steps': 30,
123
- 'guidance_scale': 3.5,
124
- 'num_images_per_prompt': 1,
125
- 'width': resolution*4,
126
- 'height': resolution*2,
127
- 'output_type': 'np',
128
- 'generator': generator,
129
- 'joint_attention_kwargs': {"scale": lora_scale}
130
- }
131
-
132
- # append controlnet hparams
133
- if len(control_image) > 0:
134
- assert len(control_mode) == len(control_image) # the count of image should be the same as control mode
135
-
136
- ctrl_hparams = {
137
- 'control_mode': [control_mode_dict[mode_] for mode_ in control_mode],
138
- 'control_image': control_image,
139
- 'control_guidance_start': control_guidance_start or [0.0 for i in range(len(control_image))],
140
- 'control_guidance_end': control_guidance_end or [1.0 for i in range(len(control_image))],
141
- 'controlnet_conditioning_scale': controlnet_conditioning_scale or [1.0 for i in range(len(control_image))],
142
- }
143
-
144
- hparam_dict.update(ctrl_hparams)
145
-
146
- # generate multi-view images
147
- with torch.no_grad():
148
- image = flux_pipe(
149
- **hparam_dict
150
- ).images
151
- return image
152
-
153
- # captioning
154
- def run_captioning(image):
155
- device = "cuda" if torch.cuda.is_available() else "cpu"
156
- torch_dtype = torch.bfloat16
157
-
158
- if isinstance(image, str): # If image is a file path
159
- image = Image.open(image).convert("RGB")
160
-
161
- prompt = "<MORE_DETAILED_CAPTION>"
162
- inputs = caption_processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
163
- # print(f"inputs {inputs}")
164
-
165
- generated_ids = caption_model.generate(
166
- input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3
167
- )
168
-
169
- generated_text = caption_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
170
- parsed_answer = caption_processor.post_process_generation(
171
- generated_text, task=prompt, image_size=(image.width, image.height)
172
- )
173
- # print(f"parsed_answer = {parsed_answer}")
174
- caption_text = parsed_answer["<MORE_DETAILED_CAPTION>"].replace("The image is ", "")
175
- return caption_text
176
-
177
-
178
- # zero123++ multi-view generation
179
- def multi_view_rgb_generation(cond_img):
180
- # generate multi-view images
181
- with torch.no_grad():
182
- output_image = zero123plus_pipeline(
183
- cond_img,
184
- num_inference_steps=zero123plus_diffusion_steps,
185
- width=resolution*2,
186
- height=resolution*2,
187
- ).images[0]
188
- return output_image
189
-
190
- # lrm reconstructions
191
- def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False, render_azimuths=None, render_elevations=None, render_radius=None, render_fov=30):
192
- images = image.unsqueeze(0).to(device)
193
- images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
194
- # breakpoint()
195
- with torch.no_grad():
196
- # get triplane
197
- planes = model.forward_planes(images, input_cameras)
198
-
199
- mesh_path_idx = os.path.join(save_path, f'{name}.obj')
200
-
201
- mesh_out = model.extract_mesh(
202
- planes,
203
- use_texture_map=export_texmap,
204
- **infer_config,
205
- )
206
- if export_texmap:
207
- vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
208
- save_obj_with_mtl(
209
- vertices.data.cpu().numpy(),
210
- uvs.data.cpu().numpy(),
211
- faces.data.cpu().numpy(),
212
- mesh_tex_idx.data.cpu().numpy(),
213
- tex_map.permute(1, 2, 0).data.cpu().numpy(),
214
- mesh_path_idx,
215
- )
216
- else:
217
- vertices, faces, vertex_colors = mesh_out
218
- save_obj(vertices, faces, vertex_colors, mesh_path_idx)
219
- print(f"Mesh saved to {mesh_path_idx}")
220
-
221
- render_size = 512
222
- if if_save_video:
223
- video_path_idx = os.path.join(save_path, f'{name}.mp4')
224
- render_size = infer_config.render_resolution
225
- ENV = load_mipmap("models/lrm/env_mipmap/6")
226
- materials = (0.0,0.9)
227
-
228
- all_mv, all_mvp, all_campos = get_render_cameras_video(
229
- batch_size=1,
230
- M=240,
231
- radius=4.5,
232
- elevation=(90, 60.0),
233
- is_flexicubes=True,
234
- fov=30
235
- )
236
-
237
- frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
238
- model,
239
- planes,
240
- render_cameras=all_mvp,
241
- camera_pos=all_campos,
242
- env=ENV,
243
- materials=materials,
244
- render_size=render_size,
245
- chunk_size=20,
246
- is_flexicubes=True,
247
- )
248
- normals = (torch.nn.functional.normalize(normals) + 1) / 2
249
- normals = normals * alphas + (1-alphas)
250
- all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
251
-
252
- # breakpoint()
253
- save_video(
254
- all_frames,
255
- video_path_idx,
256
- fps=30,
257
- )
258
- print(f"Video saved to {video_path_idx}")
259
-
260
- if render_azimuths is not None and render_elevations is not None and render_radius is not None:
261
- render_size = infer_config.render_resolution
262
- ENV = load_mipmap("models/lrm/env_mipmap/6")
263
- materials = (0.0,0.9)
264
- all_mv, all_mvp, all_campos, identity_mv = get_render_cameras_frames(
265
- batch_size=1,
266
- radius=render_radius,
267
- azimuths=render_azimuths,
268
- elevations=render_elevations,
269
- fov=30
270
- )
271
- frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
272
- model,
273
- planes,
274
- render_cameras=all_mvp,
275
- camera_pos=all_campos,
276
- env=ENV,
277
- materials=materials,
278
- render_size=render_size,
279
- render_mv = all_mv,
280
- local_normal=True,
281
- identity_mv=identity_mv,
282
- )
283
- else:
284
- normals = None
285
- frames = None
286
- albedos = None
287
-
288
- return vertices, faces, normals, frames, albedos
289
-
290
-
291
- def transform_normal(input_normal, azimuths_deg, elevations_deg, radius=4.5, is_global_to_local=False):
292
- """
293
- input_normal: in range [-1, 1], shape (b c h w)
294
- """
295
-
296
- input_normal = input_normal.permute(0, 2, 3, 1).cpu()
297
-
298
- azimuths_deg = np.array(azimuths_deg)
299
- elevations_deg = np.array(elevations_deg)
300
-
301
- if is_global_to_local:
302
- local_normal = normal_transfer.trans_global_2_local(input_normal, azimuths_deg, elevations_deg)
303
- return local_normal.permute(0, 3, 1, 2)
304
- else:
305
- global_normal = normal_transfer.trans_local_2_global(input_normal, azimuths_deg, elevations_deg, radius=radius, for_lotus=False)
306
- global_normal[..., 0] *= -1
307
- return global_normal.permute(0, 3, 1, 2)
308
-
309
- def local_normal_global_transform(local_normal_images,azimuths_deg,elevations_deg):
310
- if local_normal_images.min() >= 0:
311
- local_normal = local_normal_images.float() * 2 - 1
312
- else:
313
- local_normal = local_normal_images.float()
314
- global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
315
- global_normal[...,0] *= -1
316
- global_normal = (global_normal + 1) / 2
317
- global_normal = global_normal.permute(0, 3, 1, 2)
318
- return global_normal
319
-
320
- def main():
321
- image_pth = "examples/蓝色小怪物.webp"
322
- save_dir_path = os.path.join(save_dir, image_pth.split("/")[-1].split(".")[0])
323
- os.makedirs(save_dir_path, exist_ok=True)
324
- input_image = Image.open(image_pth)
325
- # if not args.no_rembg:
326
- input_image = remove_background(input_image, rembg_session)
327
- input_image = resize_foreground(input_image, 0.85)
328
-
329
- # generate caption
330
- image_caption = run_captioning(image_pth)
331
-
332
- # generate multi-view images
333
- output_image = multi_view_rgb_generation(input_image)
334
-
335
- # lrm reconstructions
336
- rgb_multi_view = np.asarray(output_image, dtype=np.float32) / 255.0
337
- rgb_multi_view = torch.from_numpy(rgb_multi_view).squeeze(0).permute(2, 0, 1).contiguous().float() # (3, 1024, 2048)
338
- rgb_multi_view = rearrange(rgb_multi_view, 'c (n h) (m w) -> (n m) c h w', n=2, m=2) # (8, 3, 512, 512)
339
-
340
- input_cameras = get_custom_zero123plus_input_cameras(batch_size=1, radius=3.5, fov=30).to(device)
341
-
342
- vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = \
343
- lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm',
344
- export_texmap=False, if_save_video=False, render_azimuths=isomer_azimuths,
345
- render_elevations=isomer_elevations, render_radius=isomer_radius, render_fov=30)
346
-
347
- vertices = torch.from_numpy(vertices).to(device)
348
- faces = torch.from_numpy(faces).to(device)
349
- vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
350
- vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]
351
-
352
-
353
- # lrm_3D_bundle_image = torchvision.utils.make_grid(torch.cat([lrm_multi_view_rgb.cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0) # range [0, 1]
354
- lrm_3D_bundle_image = torchvision.utils.make_grid(torch.cat([rgb_multi_view[[3,0,1,2]].cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0) # range [0, 1]
355
- # rgb_multi_view[[3,0,1,2]] : (B,3,H,W)
356
- # lrm_multi_view_normals : (B,3,H,W)
357
- # combined_images = 0.5 * rgb_multi_view[[3,0,1,2]].cpu() + 0.5 * (lrm_multi_view_normals.cpu() + 1) / 2
358
- # torchvision.utils.save_image(combined_images, os.path.join("debug_output", 'combined.png'))
359
- # breakpoint()
360
- # Use the 'tile' control mode by default; feel free to try the other Union modes
361
- control_image = [lrm_3D_bundle_image * 2 - 1]
362
- control_mode = ['tile']
363
- control_guidance_start = [0.0]
364
- control_guidance_end = [0.3]
365
- controlnet_conditioning_scale = [0.8]
366
-
367
- flux_pipe.controlnet = FluxMultiControlNetModel([flux_controlnet for _ in control_mode])
368
- # breakpoint()
369
- rgb_normal_grid = multi_view_rgb_normal_generation_with_controlnet(
370
- prompt= ' '.join(['A grid of 2x4 multi-view image, elevation 5. White background.', image_caption]),
371
- image=lrm_3D_bundle_image,
372
- strength=0.6,
373
- control_image=control_image,
374
- control_mode=control_mode,
375
- control_guidance_start=control_guidance_start,
376
- control_guidance_end=control_guidance_end,
377
- controlnet_conditioning_scale=controlnet_conditioning_scale,
378
- lora_scale=1.0
379
- ) # note that rgb_normal_grid is a (b, h, w, c) numpy array
380
-
381
- rgb_normal_grid = torch.from_numpy(rgb_normal_grid).contiguous().float()
382
- rgb_normal_grid = rearrange(rgb_normal_grid.squeeze(0), '(n h) (m w) c-> (n m) c h w', n=2, m=4) # (8, 3, 512, 512)
383
- rgb_multi_view = rgb_normal_grid[:4, :3, :, :].cuda()
384
- normal_multi_view = rgb_normal_grid[4:, :3, :, :].cuda()
385
- multi_view_mask = get_background(normal_multi_view).cuda()
386
- rgb_multi_view = rgb_multi_view * multi_view_mask + (1-multi_view_mask)
387
-
388
- # local normal to global normal
389
- global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1).cpu(), isomer_azimuths, isomer_elevations).cuda()
390
-
391
- global_normal = global_normal * multi_view_mask + (1-multi_view_mask)
392
-
393
- global_normal = global_normal.permute(0,2,3,1)
394
- multi_view_mask = multi_view_mask.squeeze(1)
395
- rgb_multi_view = rgb_multi_view.permute(0,2,3,1)
396
- # global_normal: B,H,W,3
397
- # multi_view_mask: B,H,W
398
- # rgb_multi_view: B,H,W,3
399
-
400
-
401
- meshes = reconstruction(
402
- normal_pils=global_normal,
403
- masks=multi_view_mask,
404
- weights=isomer_geo_weights,
405
- fov=30,
406
- radius=isomer_radius,
407
- camera_angles_azi=isomer_azimuths,
408
- camera_angles_ele=isomer_elevations,
409
- expansion_weight_stage1=0.1,
410
- init_type="file",
411
- init_verts=vertices,
412
- init_faces=faces,
413
- stage1_steps=0,
414
- stage2_steps=50,
415
- start_edge_len_stage1=0.1,
416
- end_edge_len_stage1=0.02,
417
- start_edge_len_stage2=0.02,
418
- end_edge_len_stage2=0.005,
419
- )
420
-
421
- save_glb_addr = projection(
422
- meshes=meshes,
423
- masks=multi_view_mask,
424
- images=rgb_multi_view,
425
- azimuths=isomer_azimuths,
426
- elevations=isomer_elevations,
427
- weights=isomer_color_weights,
428
- fov=30,
429
- radius=isomer_radius,
430
- save_dir=f"{save_dir_path}/ISOMER/",
431
- )
432
- print(f'saved to {save_glb_addr}')
433
-
434
-
435
-
436
- if __name__ == '__main__':
437
- main()
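The removed app.py block above (and the deleted script below) splits the zero123++ output grid into a batch of per-view images with einops before feeding the LRM. A minimal sketch of that rearrange pattern on a dummy tensor (shapes here are illustrative, not taken from the commit):

```py
# Illustrative only: how 'c (n h) (m w) -> (n m) c h w' turns a tiled multi-view
# grid into a batch of views that the LRM reconstruction consumes.
import torch
from einops import rearrange

grid = torch.rand(3, 1024, 1024)                       # a 2x2 grid of 512x512 RGB views
views = rearrange(grid, 'c (n h) (m w) -> (n m) c h w', n=2, m=2)
print(views.shape)                                     # torch.Size([4, 3, 512, 512])
```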
 
image_to_mesh_new.py DELETED
@@ -1,436 +0,0 @@
1
- import os
2
- from einops import rearrange
3
- from omegaconf import OmegaConf
4
- import torch
5
- import numpy as np
6
- import trimesh
7
- import torchvision
8
- import torch.nn.functional as F
9
- from PIL import Image
10
- from torchvision import transforms
11
- from torchvision.transforms import v2
12
- from transformers import AutoProcessor, AutoModelForCausalLM
13
- import rembg
14
- from diffusers import FluxPipeline, FluxControlNetImg2ImgPipeline
15
- from diffusers.models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
16
- from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, HeunDiscreteScheduler
17
- from pytorch_lightning import seed_everything
18
- import os
19
-
20
- from models.ISOMER.reconstruction_func import reconstruction
21
- from models.ISOMER.projection_func import projection
22
- from models.lrm.utils.infer_util import remove_background, resize_foreground, save_video
23
- from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
24
- from models.lrm.utils.render_utils import rotate_x, rotate_y
25
- from models.lrm.utils.train_util import instantiate_from_config
26
- from models.lrm.utils.camera_util import get_zero123plus_input_cameras, get_custom_zero123plus_input_cameras, get_flux_input_cameras
27
- from utils.tool import NormalTransfer, get_render_cameras_frames, load_mipmap
28
- from utils.tool import get_background, get_render_cameras_video, render_frames, mask_fix
29
-
30
- device = "cuda"
31
- resolution = 512
32
- save_dir = "./outputs"
33
- zero123plus_diffusion_steps = 75
34
- normal_transfer = NormalTransfer()
35
- rembg_session = rembg.new_session()
36
- isomer_azimuths = torch.from_numpy(np.array([270, 0, 90, 180])).to(device)
37
- isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).to(device)
38
- isomer_radius = 4.1
39
- isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device)
40
- isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device)
41
- # seed_everything(42)
42
-
43
- # model initialization and loading
44
- # flux
45
- print('==> Loading Flux model ...')
46
- flux_base_model_pth = "/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/models--black-forest-labs--FLUX.1-dev"
47
- flux_controlnet = FluxControlNetModel.from_pretrained("/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/flux_controlnets/FLUX.1-dev-ControlNet-Union-Pro")
48
- flux_pipe = FluxControlNetImg2ImgPipeline.from_pretrained(flux_base_model_pth, controlnet=[flux_controlnet], torch_dtype=torch.bfloat16).to(device=device, dtype=torch.bfloat16)
49
-
50
- flux_pipe.load_lora_weights('./checkpoint/flux_lora/rgb_normal_large.safetensors')
51
-
52
-
53
- flux_pipe.to(device=device, dtype=torch.bfloat16)
54
- generator = torch.Generator(device=device).manual_seed(0)
55
-
56
- # lrm
57
- print('==> Loading LRM model ...')
58
- config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
59
- model_config = config.model_config
60
- infer_config = config.infer_config
61
- model = instantiate_from_config(model_config)
62
- model_ckpt_path = "./checkpoint/lrm/final_ckpt.ckpt"
63
- state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
64
- state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
65
- model.load_state_dict(state_dict, strict=True)
66
-
67
- model = model.to(device)
68
- model.init_flexicubes_geometry(device, fovy=50.0)
69
- model = model.eval()
70
-
71
- # zero123++
72
- print('==> Loading diffusion model ...')
73
- zero123plus_pipeline = DiffusionPipeline.from_pretrained(
74
- "sudo-ai/zero123plus-v1.2",
75
- custom_pipeline="./models/zero123plus",
76
- torch_dtype=torch.float16,
77
- )
78
- zero123plus_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
79
- zero123plus_pipeline.scheduler.config, timestep_spacing='trailing'
80
- )
81
- unet_ckpt_path = "./checkpoint/zero123++/flexgen_19w.ckpt"
82
- state_dict = torch.load(unet_ckpt_path, map_location='cpu')['state_dict']
83
- state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')}
84
- zero123plus_pipeline.unet.load_state_dict(state_dict, strict=True)
85
- zero123plus_pipeline = zero123plus_pipeline.to(device)
86
-
87
- # unet_ckpt_path = "checkpoint/zero123++/diffusion_pytorch_model.bin"
88
- # state_dict = torch.load(unet_ckpt_path, map_location='cpu')
89
- # zero123plus_pipeline.unet.load_state_dict(state_dict, strict=True)
90
- # zero123plus_pipeline = zero123plus_pipeline.to(device)
91
-
92
- # florence
93
- caption_model = AutoModelForCausalLM.from_pretrained(
94
- "/hpc2hdd/home/jlin695/.cache/huggingface/hub/models--multimodalart--Florence-2-large-no-flash-attn/snapshots/8db3793cf5b453b2ccfb3a4f613b403b2e6b7ca2", torch_dtype=torch.bfloat16, trust_remote_code=True,
95
- ).to(device)
96
- caption_processor = AutoProcessor.from_pretrained("/hpc2hdd/home/jlin695/.cache/huggingface/hub/models--multimodalart--Florence-2-large-no-flash-attn/snapshots/8db3793cf5b453b2ccfb3a4f613b403b2e6b7ca2", trust_remote_code=True)
97
-
98
- # Flux multi-view generation
99
- def multi_view_rgb_normal_generation_with_controlnet(prompt, image, strength=1.0,
100
- control_image=[],
101
- control_mode=[],
102
- control_guidance_start=None,
103
- control_guidance_end=None,
104
- controlnet_conditioning_scale=None,
105
- lora_scale=1.0
106
- ):
107
- control_mode_dict = {
108
- 'canny': 0,
109
- 'tile': 1,
110
- 'depth': 2,
111
- 'blur': 3,
112
- 'pose': 4,
113
- 'gray': 5,
114
- 'lq': 6,
115
- } # for https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union only
116
-
117
- hparam_dict = {
118
- 'prompt': prompt,
119
- 'image': image,
120
- 'strength': strength,
121
- 'num_inference_steps': 30,
122
- 'guidance_scale': 3.5,
123
- 'num_images_per_prompt': 1,
124
- 'width': resolution*4,
125
- 'height': resolution*2,
126
- 'output_type': 'np',
127
- 'generator': generator,
128
- 'joint_attention_kwargs': {"scale": lora_scale}
129
- }
130
-
131
- # append controlnet hparams
132
- if len(control_image) > 0:
133
- assert len(control_mode) == len(control_image) # the number of control images must match the number of control modes
134
-
135
- ctrl_hparams = {
136
- 'control_mode': [control_mode_dict[mode_] for mode_ in control_mode],
137
- 'control_image': control_image,
138
- 'control_guidance_start': control_guidance_start or [0.0 for i in range(len(control_image))],
139
- 'control_guidance_end': control_guidance_end or [1.0 for i in range(len(control_image))],
140
- 'controlnet_conditioning_scale': controlnet_conditioning_scale or [1.0 for i in range(len(control_image))],
141
- }
142
-
143
- hparam_dict.update(ctrl_hparams)
144
-
145
- # generate multi-view images
146
- with torch.no_grad():
147
- image = flux_pipe(
148
- **hparam_dict
149
- ).images
150
- return image
151
-
152
- # captioning
153
- def run_captioning(image):
154
- device = "cuda" if torch.cuda.is_available() else "cpu"
155
- torch_dtype = torch.bfloat16
156
-
157
- if isinstance(image, str): # If image is a file path
158
- image = Image.open(image).convert("RGB")
159
-
160
- prompt = "<MORE_DETAILED_CAPTION>"
161
- inputs = caption_processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
162
- # print(f"inputs {inputs}")
163
-
164
- generated_ids = caption_model.generate(
165
- input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3
166
- )
167
-
168
- generated_text = caption_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
169
- parsed_answer = caption_processor.post_process_generation(
170
- generated_text, task=prompt, image_size=(image.width, image.height)
171
- )
172
- # print(f"parsed_answer = {parsed_answer}")
173
- caption_text = parsed_answer["<MORE_DETAILED_CAPTION>"].replace("The image is ", "")
174
- return caption_text
175
-
176
-
177
- # zero123++ multi-view generation
178
- def multi_view_rgb_generation(cond_img):
179
- # generate multi-view images
180
- with torch.no_grad():
181
- output_image = zero123plus_pipeline(
182
- cond_img,
183
- num_inference_steps=zero123plus_diffusion_steps,
184
- width=resolution*2,
185
- height=resolution*2,
186
- ).images[0]
187
- return output_image
188
-
189
- # lrm reconstructions
190
- def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False, render_azimuths=None, render_elevations=None, render_radius=None, render_fov=30):
191
- images = image.unsqueeze(0).to(device)
192
- images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
193
- # breakpoint()
194
- with torch.no_grad():
195
- # get triplane
196
- planes = model.forward_planes(images, input_cameras)
197
-
198
- mesh_path_idx = os.path.join(save_path, f'{name}.obj')
199
-
200
- mesh_out = model.extract_mesh(
201
- planes,
202
- use_texture_map=export_texmap,
203
- **infer_config,
204
- )
205
- if export_texmap:
206
- vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
207
- save_obj_with_mtl(
208
- vertices.data.cpu().numpy(),
209
- uvs.data.cpu().numpy(),
210
- faces.data.cpu().numpy(),
211
- mesh_tex_idx.data.cpu().numpy(),
212
- tex_map.permute(1, 2, 0).data.cpu().numpy(),
213
- mesh_path_idx,
214
- )
215
- else:
216
- vertices, faces, vertex_colors = mesh_out
217
- save_obj(vertices, faces, vertex_colors, mesh_path_idx)
218
- print(f"Mesh saved to {mesh_path_idx}")
219
-
220
- render_size = 512
221
- if if_save_video:
222
- video_path_idx = os.path.join(save_path, f'{name}.mp4')
223
- render_size = infer_config.render_resolution
224
- ENV = load_mipmap("models/lrm/env_mipmap/6")
225
- materials = (0.0,0.9)
226
-
227
- all_mv, all_mvp, all_campos = get_render_cameras_video(
228
- batch_size=1,
229
- M=240,
230
- radius=4.5,
231
- elevation=(90, 60.0),
232
- is_flexicubes=True,
233
- fov=30
234
- )
235
-
236
- frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
237
- model,
238
- planes,
239
- render_cameras=all_mvp,
240
- camera_pos=all_campos,
241
- env=ENV,
242
- materials=materials,
243
- render_size=render_size,
244
- chunk_size=20,
245
- is_flexicubes=True,
246
- )
247
- normals = (torch.nn.functional.normalize(normals) + 1) / 2
248
- normals = normals * alphas + (1-alphas)
249
- all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
250
-
251
- # breakpoint()
252
- save_video(
253
- all_frames,
254
- video_path_idx,
255
- fps=30,
256
- )
257
- print(f"Video saved to {video_path_idx}")
258
-
259
- if render_azimuths is not None and render_elevations is not None and render_radius is not None:
260
- render_size = infer_config.render_resolution
261
- ENV = load_mipmap("models/lrm/env_mipmap/6")
262
- materials = (0.0,0.9)
263
- all_mv, all_mvp, all_campos, identity_mv = get_render_cameras_frames(
264
- batch_size=1,
265
- radius=render_radius,
266
- azimuths=render_azimuths,
267
- elevations=render_elevations,
268
- fov=30
269
- )
270
- frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
271
- model,
272
- planes,
273
- render_cameras=all_mvp,
274
- camera_pos=all_campos,
275
- env=ENV,
276
- materials=materials,
277
- render_size=render_size,
278
- render_mv = all_mv,
279
- local_normal=True,
280
- identity_mv=identity_mv,
281
- )
282
- else:
283
- normals = None
284
- frames = None
285
- albedos = None
286
-
287
- return vertices, faces, normals, frames, albedos
288
-
289
-
290
- def transform_normal(input_normal, azimuths_deg, elevations_deg, radius=4.5, is_global_to_local=False):
291
- """
292
- input_normal: in range [-1, 1], shape (b c h w)
293
- """
294
-
295
- input_normal = input_normal.permute(0, 2, 3, 1).cpu()
296
-
297
- azimuths_deg = np.array(azimuths_deg)
298
- elevations_deg = np.array(elevations_deg)
299
-
300
- if is_global_to_local:
301
- local_normal = normal_transfer.trans_global_2_local(input_normal, azimuths_deg, elevations_deg)
302
- return local_normal.permute(0, 3, 1, 2)
303
- else:
304
- global_normal = normal_transfer.trans_local_2_global(input_normal, azimuths_deg, elevations_deg, radius=radius, for_lotus=False)
305
- global_normal[..., 0] *= -1
306
- return global_normal.permute(0, 3, 1, 2)
307
-
308
- def local_normal_global_transform(local_normal_images,azimuths_deg,elevations_deg):
309
- if local_normal_images.min() >= 0:
310
- local_normal = local_normal_images.float() * 2 - 1
311
- else:
312
- local_normal = local_normal_images.float()
313
- global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
314
- global_normal[...,0] *= -1
315
- global_normal = (global_normal + 1) / 2
316
- global_normal = global_normal.permute(0, 3, 1, 2)
317
- return global_normal
318
-
319
- def main():
320
- image_pth = "examples/蓝色小怪物.webp"
321
- save_dir_path = os.path.join(save_dir, image_pth.split("/")[-1].split(".")[0])
322
- os.makedirs(save_dir_path, exist_ok=True)
323
- input_image = Image.open(image_pth)
324
- # if not args.no_rembg:
325
- input_image = remove_background(input_image, rembg_session)
326
- input_image = resize_foreground(input_image, 0.85)
327
-
328
- # generate caption
329
- image_caption = run_captioning(image_pth)
330
-
331
- # generate multi-view images
332
- output_image = multi_view_rgb_generation(input_image)
333
-
334
- # lrm reconstructions
335
- rgb_multi_view = np.asarray(output_image, dtype=np.float32) / 255.0
336
- rgb_multi_view = torch.from_numpy(rgb_multi_view).squeeze(0).permute(2, 0, 1).contiguous().float() # (3, 1024, 2048)
337
- rgb_multi_view = rearrange(rgb_multi_view, 'c (n h) (m w) -> (n m) c h w', n=2, m=2) # (8, 3, 512, 512)
338
-
339
- input_cameras = get_custom_zero123plus_input_cameras(batch_size=1, radius=3.5, fov=30).to(device)
340
-
341
- vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = \
342
- lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm',
343
- export_texmap=False, if_save_video=False, render_azimuths=isomer_azimuths,
344
- render_elevations=isomer_elevations, render_radius=isomer_radius, render_fov=30)
345
-
346
- vertices = torch.from_numpy(vertices).to(device)
347
- faces = torch.from_numpy(faces).to(device)
348
- vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
349
- vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]
350
-
351
-
352
- # lrm_3D_bundle_image = torchvision.utils.make_grid(torch.cat([lrm_multi_view_rgb.cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0) # range [0, 1]
353
- lrm_3D_bundle_image = torchvision.utils.make_grid(torch.cat([rgb_multi_view[[3,0,1,2]].cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0) # range [0, 1]
354
- # rgb_multi_view[[3,0,1,2]] : (B,3,H,W)
355
- # lrm_multi_view_normals : (B,3,H,W)
356
- # combined_images = 0.5 * rgb_multi_view[[3,0,1,2]].cpu() + 0.5 * (lrm_multi_view_normals.cpu() + 1) / 2
357
- # torchvision.utils.save_image(combined_images, os.path.join("debug_output", 'combined.png'))
358
- # breakpoint()
359
- # Use the 'tile' control mode by default; feel free to try the other Union modes
360
- control_image = [lrm_3D_bundle_image * 2 - 1]
361
- control_mode = ['tile']
362
- control_guidance_start = [0.0]
363
- control_guidance_end = [0.3]
364
- controlnet_conditioning_scale = [0.8]
365
-
366
- flux_pipe.controlnet = FluxMultiControlNetModel([flux_controlnet for _ in control_mode])
367
- # breakpoint()
368
- rgb_normal_grid = multi_view_rgb_normal_generation_with_controlnet(
369
- prompt= ' '.join(['A grid of 2x4 multi-view image, elevation 5. White background.', image_caption]),
370
- image=lrm_3D_bundle_image,
371
- strength=0.6,
372
- control_image=control_image,
373
- control_mode=control_mode,
374
- control_guidance_start=control_guidance_start,
375
- control_guidance_end=control_guidance_end,
376
- controlnet_conditioning_scale=controlnet_conditioning_scale,
377
- lora_scale=1.0
378
- ) # note that rgb_normal_grid is a (b, h, w, c) numpy array
379
-
380
- rgb_normal_grid = torch.from_numpy(rgb_normal_grid).contiguous().float()
381
- rgb_normal_grid = rearrange(rgb_normal_grid.squeeze(0), '(n h) (m w) c-> (n m) c h w', n=2, m=4) # (8, 3, 512, 512)
382
- rgb_multi_view = rgb_normal_grid[:4, :3, :, :].cuda()
383
- normal_multi_view = rgb_normal_grid[4:, :3, :, :].cuda()
384
- multi_view_mask = get_background(normal_multi_view).cuda()
385
- rgb_multi_view = rgb_multi_view * multi_view_mask + (1-multi_view_mask)
386
-
387
- # local normal to global normal
388
- global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1).cpu(), isomer_azimuths, isomer_elevations).cuda()
389
-
390
- global_normal = global_normal * multi_view_mask + (1-multi_view_mask)
391
-
392
- global_normal = global_normal.permute(0,2,3,1)
393
- multi_view_mask = multi_view_mask.squeeze(1)
394
- rgb_multi_view = rgb_multi_view.permute(0,2,3,1)
395
- # global_normal: B,H,W,3
396
- # multi_view_mask: B,H,W
397
- # rgb_multi_view: B,H,W,3
398
-
399
-
400
- meshes = reconstruction(
401
- normal_pils=global_normal,
402
- masks=multi_view_mask,
403
- weights=isomer_geo_weights,
404
- fov=30,
405
- radius=isomer_radius,
406
- camera_angles_azi=isomer_azimuths,
407
- camera_angles_ele=isomer_elevations,
408
- expansion_weight_stage1=0.1,
409
- init_type="file",
410
- init_verts=vertices,
411
- init_faces=faces,
412
- stage1_steps=0,
413
- stage2_steps=50,
414
- start_edge_len_stage1=0.1,
415
- end_edge_len_stage1=0.02,
416
- start_edge_len_stage2=0.02,
417
- end_edge_len_stage2=0.005,
418
- )
419
-
420
- save_glb_addr = projection(
421
- meshes=meshes,
422
- masks=multi_view_mask,
423
- images=rgb_multi_view,
424
- azimuths=isomer_azimuths,
425
- elevations=isomer_elevations,
426
- weights=isomer_color_weights,
427
- fov=30,
428
- radius=isomer_radius,
429
- save_dir=f"{save_dir_path}/ISOMER/",
430
- )
431
- print(f'saved to {save_glb_addr}')
432
-
433
-
434
-
435
- if __name__ == '__main__':
436
- main()
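The deleted script assembles the FLUX ControlNet-Union keyword arguments with `or`-based defaults inside multi_view_rgb_normal_generation_with_controlnet. A minimal sketch of just that assembly step (the helper name and defaults below are illustrative, not part of the commit):

```py
# Illustrative sketch of the kwarg assembly done in
# multi_view_rgb_normal_generation_with_controlnet; not part of the commit.
CONTROL_MODE_DICT = {'canny': 0, 'tile': 1, 'depth': 2, 'blur': 3, 'pose': 4, 'gray': 5, 'lq': 6}

def build_controlnet_kwargs(control_image, control_mode,
                            control_guidance_start=None, control_guidance_end=None,
                            controlnet_conditioning_scale=None):
    # the number of control images must match the number of control modes
    assert len(control_mode) == len(control_image)
    n = len(control_image)
    return {
        'control_mode': [CONTROL_MODE_DICT[m] for m in control_mode],
        'control_image': control_image,
        'control_guidance_start': control_guidance_start or [0.0] * n,
        'control_guidance_end': control_guidance_end or [1.0] * n,
        'controlnet_conditioning_scale': controlnet_conditioning_scale or [1.0] * n,
    }

# e.g. build_controlnet_kwargs([bundle_image], ['tile'], control_guidance_end=[0.3])
```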
 
live_preview_helpers.py DELETED
@@ -1,167 +0,0 @@
1
- import torch
2
- import numpy as np
3
- from diffusers import FluxPipeline, AutoencoderTiny, FlowMatchEulerDiscreteScheduler
4
- from typing import Any, Dict, List, Optional, Union
5
-
6
- # Helper functions
7
- def calculate_shift(
8
- image_seq_len,
9
- base_seq_len: int = 256,
10
- max_seq_len: int = 4096,
11
- base_shift: float = 0.5,
12
- max_shift: float = 1.16,
13
- ):
14
- m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
15
- b = base_shift - m * base_seq_len
16
- mu = image_seq_len * m + b
17
- return mu
18
-
19
- def retrieve_timesteps(
20
- scheduler,
21
- num_inference_steps: Optional[int] = None,
22
- device: Optional[Union[str, torch.device]] = None,
23
- timesteps: Optional[List[int]] = None,
24
- sigmas: Optional[List[float]] = None,
25
- **kwargs,
26
- ):
27
- if timesteps is not None and sigmas is not None:
28
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
29
- if timesteps is not None:
30
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
31
- timesteps = scheduler.timesteps
32
- num_inference_steps = len(timesteps)
33
- elif sigmas is not None:
34
- scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
35
- timesteps = scheduler.timesteps
36
- num_inference_steps = len(timesteps)
37
- else:
38
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
39
- timesteps = scheduler.timesteps
40
- return timesteps, num_inference_steps
41
-
42
- # FLUX pipeline function
43
- @torch.inference_mode()
44
- def flux_pipe_call_that_returns_an_iterable_of_images(
45
- self,
46
- prompt: Union[str, List[str]] = None,
47
- prompt_2: Optional[Union[str, List[str]]] = None,
48
- height: Optional[int] = None,
49
- width: Optional[int] = None,
50
- num_inference_steps: int = 28,
51
- timesteps: List[int] = None,
52
- guidance_scale: float = 3.5,
53
- num_images_per_prompt: Optional[int] = 1,
54
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
55
- latents: Optional[torch.FloatTensor] = None,
56
- prompt_embeds: Optional[torch.FloatTensor] = None,
57
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
58
- output_type: Optional[str] = "pil",
59
- return_dict: bool = True,
60
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
61
- max_sequence_length: int = 512,
62
- good_vae: Optional[Any] = None,
63
- ):
64
- height = height or self.default_sample_size * self.vae_scale_factor
65
- width = width or self.default_sample_size * self.vae_scale_factor
66
-
67
- # 1. Check inputs
68
- self.check_inputs(
69
- prompt,
70
- prompt_2,
71
- height,
72
- width,
73
- prompt_embeds=prompt_embeds,
74
- pooled_prompt_embeds=pooled_prompt_embeds,
75
- max_sequence_length=max_sequence_length,
76
- )
77
-
78
- self._guidance_scale = guidance_scale
79
- self._joint_attention_kwargs = joint_attention_kwargs
80
- self._interrupt = False
81
-
82
- # 2. Define call parameters
83
- batch_size = 1 if isinstance(prompt, str) else len(prompt)
84
- device = self._execution_device
85
-
86
- # 3. Encode prompt
87
- lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
88
- prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
89
- prompt=prompt,
90
- prompt_2=prompt_2,
91
- prompt_embeds=prompt_embeds,
92
- pooled_prompt_embeds=pooled_prompt_embeds,
93
- device=device,
94
- num_images_per_prompt=num_images_per_prompt,
95
- max_sequence_length=max_sequence_length,
96
- lora_scale=lora_scale,
97
- )
98
- # 4. Prepare latent variables
99
- num_channels_latents = self.transformer.config.in_channels // 4
100
- latents, latent_image_ids = self.prepare_latents(
101
- batch_size * num_images_per_prompt,
102
- num_channels_latents,
103
- height,
104
- width,
105
- prompt_embeds.dtype,
106
- device,
107
- generator,
108
- latents,
109
- )
110
- # 5. Prepare timesteps
111
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
112
- image_seq_len = latents.shape[1]
113
- mu = calculate_shift(
114
- image_seq_len,
115
- self.scheduler.config.base_image_seq_len,
116
- self.scheduler.config.max_image_seq_len,
117
- self.scheduler.config.base_shift,
118
- self.scheduler.config.max_shift,
119
- )
120
- timesteps, num_inference_steps = retrieve_timesteps(
121
- self.scheduler,
122
- num_inference_steps,
123
- device,
124
- timesteps,
125
- sigmas,
126
- mu=mu,
127
- )
128
- self._num_timesteps = len(timesteps)
129
-
130
- # Handle guidance
131
- guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
132
-
133
- # 6. Denoising loop
134
- for i, t in enumerate(timesteps):
135
- print(f"Inference step {i+1}/{num_inference_steps}")
136
- if self.interrupt:
137
- continue
138
-
139
- timestep = t.expand(latents.shape[0]).to(latents.dtype)
140
-
141
- noise_pred = self.transformer(
142
- hidden_states=latents,
143
- timestep=timestep / 1000,
144
- guidance=guidance,
145
- pooled_projections=pooled_prompt_embeds,
146
- encoder_hidden_states=prompt_embeds,
147
- txt_ids=text_ids,
148
- img_ids=latent_image_ids,
149
- joint_attention_kwargs=self.joint_attention_kwargs,
150
- return_dict=False,
151
- )[0]
152
- # Yield intermediate result
153
- latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
154
- latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
155
- image = self.vae.decode(latents_for_image, return_dict=False)[0]
156
- yield self.image_processor.postprocess(image, output_type=output_type)[0]
157
-
158
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
159
- torch.cuda.empty_cache()
160
-
161
- # Final image using good_vae
162
- latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
163
- latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor
164
- image = good_vae.decode(latents, return_dict=False)[0]
165
- self.maybe_free_model_hooks()
166
- torch.cuda.empty_cache()
167
- yield self.image_processor.postprocess(image, output_type=output_type)[0]
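live_preview_helpers.py is removed in this commit; its generator-style call was imported by the old app.py. A sketch of how such a helper is typically bound onto a FluxPipeline for streamed previews (the hub model IDs and step count below are assumptions, not taken from the commit):

```py
# Assumed usage of the deleted helper (illustrative): bind the generator onto a
# FluxPipeline so each denoising step can be decoded and shown as a live preview.
import types
import torch
from diffusers import FluxPipeline, AutoencoderTiny
from live_preview_helpers import flux_pipe_call_that_returns_an_iterable_of_images

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
good_vae = pipe.vae                                    # keep the full VAE for the final frame
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to("cuda")
pipe.flux_pipe_call_that_returns_an_iterable_of_images = types.MethodType(
    flux_pipe_call_that_returns_an_iterable_of_images, pipe
)

for i, preview in enumerate(pipe.flux_pipe_call_that_returns_an_iterable_of_images(
    prompt="a cute 3D mushroom character, white background",
    num_inference_steps=8,
    good_vae=good_vae,
)):
    preview.save(f"preview_{i}.png")                   # every yield is a postprocessed PIL image
```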
 
models/llm/__pycache__/llm.cpython-310.pyc ADDED
Binary file (5.41 kB).
 
models/llm/llm.py ADDED
@@ -0,0 +1,97 @@
1
+
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+ # device = "cuda" # the device to load the model onto
5
+ model_name_or_dir="Qwen/Qwen2-7B-Instruct"
6
+
7
+
8
+ DEFAULT_SYSTEM_PROMPT = """*Given the user's input describing an object, concept, or vague idea, generate a concise and vivid prompt for the diffusion model that portrays a 3D object based solely on that input. Do not include scenes or backgrounds. The prompt should include specific descriptions for each of the four views—front, left side, rear, and right side—that will be displayed in a 2x4 grid (RGB images on the top row and normal maps on the bottom row). Put all descriptions in one single line. Focus on enhancing the cuteness and 3D qualities of the object without including any background or scene elements. Use descriptive adjectives and, if appropriate, stylistic elements to amplify the object's appeal.*
9
+
10
+ ---
11
+
12
+ **Examples: (Please follow the OUTPUT Format of the following examples.)**
13
+
14
+ - **User's Input:** "我喜欢蘑菇."
15
+ A charming 3D mushroom character with a cheerful expression and blushing cheeks, styled in a whimsical, cartoonish manner. Front view displays a wide, happy smile, round eyes, and a polka-dotted cap with a small ladybug perched on top; left side view reveals a miniature satchel with a tiny acorn charm hanging from its stem; rear view shows a cute, tiny backpack decorated with mushroom patterns and a small patch of grass at the base; right side view features a petite, colorful umbrella tucked under its cap, with a ladybug sitting on the handle. No background. Arrange in a 2x4 grid with RGB images on top and normal maps below.
16
+
17
+ - **User's Input:** "画点关于太空的东西吧."
18
+ A delightful 3D astronaut plush toy with oversized, twinkling eyes and a tiny, shiny helmet, styled in an endearing, kawaii fashion. Front view showcases a joyful smile, a sparkly visor, and a round emblem with a star on the chest; left side view highlights a small flag patch on the arm, with a tiny rocket embroidery; rear view reveals a heart-shaped mini oxygen tank with a playful bow attached; right side view displays a waving hand adorned with tiny, glittering stars and a wristband with planets. No background. Display in a 2x4 grid, top row RGB images, bottom row normal maps.
19
+
20
+ - **User's Input:** "老哥,画条龙?"
21
+ A tiny, chubby 3D dragon with a joyful expression and dainty wings, styled in a cute, fantasy-inspired manner. Front view presents large, sparkling eyes, small curved horns, and a toothy grin; left side view features a little pouch hanging from its neck with a golden coin peeking out; rear view reveals a heart-shaped tail adorned with small, shimmering scales; right side view displays a miniature shield with a dragon emblem, and a wing folded in a playful manner. No background. Presented in a 2x4 grid with RGB images above and normal maps below.
22
+
23
+ - **User's Input:** "Maybe a robot?"
24
+ A lovable 3D robot with a round, friendly body and an inviting smile, styled in a sleek, minimalist design. Front view shows glowing, expressive eyes, a cheerful mouth, and a touch-screen panel with a smiley face; left side view highlights a side antenna with a blinking light and a small digital clock display; rear view reveals a charming power pack with colorful circuits and a sticker of a smiling sun; right side view features a mechanical arm holding a tiny flower with a ladybug perched on a petal. No scene elements. Organize in a 2x4 grid, RGB images on the top row, normal maps on the bottom row.
25
+
26
+ ---
27
+
28
+ **Tips:**
29
+
30
+ - **Use Stylized Descriptions:** Mention styles that enhance cuteness (e.g., chibi, kawaii, cartoonish).
31
+
32
+ - **Incorporate Expressive Features:** Emphasize features like big eyes, smiles, or playful accessories.
33
+
34
+ - **Tailor View-Specific Details:** Ensure each view adds unique details to enrich the object's visual appeal.
35
+
36
+ - **Avoid Ambiguity:** Make sure the prompt is specific enough for the model to interpret accurately but doesn't include unnecessary information.
37
+
38
+ OUTPUT THE PROMPT ONLY!
39
+ OUTPUT ENGLISH ONLY! NOT ANY OTHER LANGUAGE, E.G., CHINESE!"""
40
+
41
+ def load_llm_model(model_name_or_dir, torch_dtype='auto', device_map='cpu'):
42
+ model = AutoModelForCausalLM.from_pretrained(
43
+ model_name_or_dir,
44
+ torch_dtype=torch_dtype,
45
+ # torch_dtype=torch.float8_e5m2,
46
+ # torch_dtype=torch.float16,
47
+ device_map=device_map
48
+ )
49
+ print(f'set llm model to {model_name_or_dir}')
50
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_dir)
51
+ print(f'set llm tokenizer to {model_name_or_dir}')
52
+ return model, tokenizer
53
+
54
+
55
+ # print(f"Before load llm model: {torch.cuda.memory_allocated() / 1024**3} GB")
56
+ # load_model()
57
+ # print(f"After load llm model: {torch.cuda.memory_allocated() / 1024**3} GB")
58
+
59
+ def get_llm_response(model, tokenizer, user_prompt, seed=None, system_prompt=DEFAULT_SYSTEM_PROMPT):
60
+ # global model
61
+ # global tokenizer
62
+ # load_model()
63
+
64
+ messages = [
65
+ {"role": "system", "content": system_prompt},
66
+ {"role": "user", "content": user_prompt}
67
+ ]
68
+ text = tokenizer.apply_chat_template(
69
+ messages,
70
+ tokenize=False,
71
+ add_generation_prompt=True
72
+ )
73
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
74
+
75
+ if seed is not None:
76
+ torch.manual_seed(seed)
77
+
78
+ # breakpoint()
79
+ generated_ids = model.generate(
80
+ model_inputs.input_ids,
81
+ max_new_tokens=512,
82
+ temperature=0.7,
83
+ )
84
+ generated_ids = [
85
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
86
+ ]
87
+
88
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
89
+
90
+ return response
91
+
92
+ # if __name__ == "__main__":
93
+
94
+ # user_prompt="哈利波特"
95
+ # rsp = get_response(user_prompt, seed=0)
96
+ # print(rsp)
97
+ # breakpoint()
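A minimal usage sketch for the helper added above (the device map and seed are assumptions; the Chinese input is one of the examples from the system prompt):

```py
# Illustrative usage of models/llm/llm.py; not part of the commit.
from models.llm.llm import load_llm_model, get_llm_response

model, tokenizer = load_llm_model("Qwen/Qwen2-7B-Instruct", torch_dtype="auto", device_map="cuda")
flux_prompt = get_llm_response(model, tokenizer, "老哥,画条龙?", seed=0)
print(flux_prompt)   # expected: a single-line English prompt describing the four views
```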
pipeline/custom_pipelines/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
2
+ from .pipeline_flux_img2img import FluxImg2ImgPipeline
3
+ from .pipeline_flux_prior_redux import FluxPriorReduxPipeline
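These re-exports presumably let the app import the patched pipelines from a single place, e.g.:

```py
# Assumed import style enabled by this package __init__ (illustrative).
from pipeline.custom_pipelines import FluxControlNetImg2ImgPipeline, FluxPriorReduxPipeline
```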
pipeline/custom_pipelines/pipeline_flux_controlnet_image_to_image.py ADDED
@@ -0,0 +1,1004 @@
1
+ # copied from diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
2
+
3
+ import inspect
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from transformers import (
9
+ CLIPTextModel,
10
+ CLIPTokenizer,
11
+ T5EncoderModel,
12
+ T5TokenizerFast,
13
+ )
14
+
15
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
16
+ from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
17
+ from diffusers.models.autoencoders import AutoencoderKL
18
+ from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
19
+ from diffusers.models.transformers import FluxTransformer2DModel
20
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
21
+ from diffusers.utils import (
22
+ USE_PEFT_BACKEND,
23
+ is_torch_xla_available,
24
+ logging,
25
+ replace_example_docstring,
26
+ scale_lora_layers,
27
+ unscale_lora_layers,
28
+ )
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
31
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
32
+
33
+
34
+ if is_torch_xla_available():
35
+ import torch_xla.core.xla_model as xm
36
+
37
+ XLA_AVAILABLE = True
38
+ else:
39
+ XLA_AVAILABLE = False
40
+
41
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42
+
43
+ EXAMPLE_DOC_STRING = """
44
+ Examples:
45
+ ```py
46
+ >>> import torch
47
+ >>> from diffusers import FluxControlNetImg2ImgPipeline, FluxControlNetModel
48
+ >>> from diffusers.utils import load_image
49
+
50
+ >>> device = "cuda" if torch.cuda.is_available() else "cpu"
51
+
52
+ >>> controlnet = FluxControlNetModel.from_pretrained(
53
+ ... "InstantX/FLUX.1-dev-Controlnet-Canny-alpha", torch_dtype=torch.bfloat16
54
+ ... )
55
+
56
+ >>> pipe = FluxControlNetImg2ImgPipeline.from_pretrained(
57
+ ... "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.float16
58
+ ... )
59
+
60
+ >>> pipe.text_encoder.to(torch.float16)
61
+ >>> pipe.controlnet.to(torch.float16)
62
+ >>> pipe.to("cuda")
63
+
64
+ >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
65
+ >>> init_image = load_image(
66
+ ... "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
67
+ ... )
68
+
69
+ >>> prompt = "A girl in city, 25 years old, cool, futuristic"
70
+ >>> image = pipe(
71
+ ... prompt,
72
+ ... image=init_image,
73
+ ... control_image=control_image,
74
+ ... control_guidance_start=0.2,
75
+ ... control_guidance_end=0.8,
76
+ ... controlnet_conditioning_scale=1.0,
77
+ ... strength=0.7,
78
+ ... num_inference_steps=2,
79
+ ... guidance_scale=3.5,
80
+ ... ).images[0]
81
+ >>> image.save("flux_controlnet_img2img.png")
82
+ ```
83
+ """
84
+
85
+
86
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
87
+ def calculate_shift(
88
+ image_seq_len,
89
+ base_seq_len: int = 256,
90
+ max_seq_len: int = 4096,
91
+ base_shift: float = 0.5,
92
+ max_shift: float = 1.16,
93
+ ):
94
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
95
+ b = base_shift - m * base_seq_len
96
+ mu = image_seq_len * m + b
97
+ return mu
98
+
99
+
100
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
101
+ def retrieve_latents(
102
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
103
+ ):
104
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
105
+ return encoder_output.latent_dist.sample(generator)
106
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
107
+ return encoder_output.latent_dist.mode()
108
+ elif hasattr(encoder_output, "latents"):
109
+ return encoder_output.latents
110
+ else:
111
+ raise AttributeError("Could not access latents of provided encoder_output")
112
+
113
+
114
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
115
+ def retrieve_timesteps(
116
+ scheduler,
117
+ num_inference_steps: Optional[int] = None,
118
+ device: Optional[Union[str, torch.device]] = None,
119
+ timesteps: Optional[List[int]] = None,
120
+ sigmas: Optional[List[float]] = None,
121
+ **kwargs,
122
+ ):
123
+ r"""
124
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
125
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
126
+
127
+ Args:
128
+ scheduler (`SchedulerMixin`):
129
+ The scheduler to get timesteps from.
130
+ num_inference_steps (`int`):
131
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
132
+ must be `None`.
133
+ device (`str` or `torch.device`, *optional*):
134
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
135
+ timesteps (`List[int]`, *optional*):
136
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
137
+ `num_inference_steps` and `sigmas` must be `None`.
138
+ sigmas (`List[float]`, *optional*):
139
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
140
+ `num_inference_steps` and `timesteps` must be `None`.
141
+
142
+ Returns:
143
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
144
+ second element is the number of inference steps.
145
+ """
146
+ if timesteps is not None and sigmas is not None:
147
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
148
+ if timesteps is not None:
149
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
150
+ if not accepts_timesteps:
151
+ raise ValueError(
152
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
153
+ f" timestep schedules. Please check whether you are using the correct scheduler."
154
+ )
155
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
156
+ timesteps = scheduler.timesteps
157
+ num_inference_steps = len(timesteps)
158
+ else:
159
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
160
+ if accept_sigmas:
161
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
162
+ # if not accept_sigmas:
163
+ # raise ValueError(
164
+ # f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
165
+ # f" sigmas schedules. Please check whether you are using the correct scheduler."
166
+ # )
167
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
168
+ timesteps = scheduler.timesteps
169
+ num_inference_steps = len(timesteps)
170
+ else:
171
+ scheduler.set_timesteps(num_inference_steps, device=device)#, **kwargs)
172
+ timesteps = scheduler.timesteps
173
+
174
+ return timesteps, num_inference_steps
175
+
176
+
177
+ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
178
+ r"""
179
+ The Flux controlnet pipeline for image-to-image generation.
180
+
181
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
182
+
183
+ Args:
184
+ transformer ([`FluxTransformer2DModel`]):
185
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
186
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
187
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
188
+ vae ([`AutoencoderKL`]):
189
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
190
+ text_encoder ([`CLIPTextModel`]):
191
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
192
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
193
+ text_encoder_2 ([`T5EncoderModel`]):
194
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
195
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
196
+ tokenizer (`CLIPTokenizer`):
197
+ Tokenizer of class
198
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
199
+ tokenizer_2 (`T5TokenizerFast`):
200
+ Second Tokenizer of class
201
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
202
+ """
203
+
204
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
205
+ _optional_components = []
206
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
207
+
208
+ def __init__(
209
+ self,
210
+ scheduler: FlowMatchEulerDiscreteScheduler,
211
+ vae: AutoencoderKL,
212
+ text_encoder: CLIPTextModel,
213
+ tokenizer: CLIPTokenizer,
214
+ text_encoder_2: T5EncoderModel,
215
+ tokenizer_2: T5TokenizerFast,
216
+ transformer: FluxTransformer2DModel,
217
+ controlnet: Union[
218
+ FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel
219
+ ],
220
+ ):
221
+ super().__init__()
222
+ if isinstance(controlnet, (list, tuple)):
223
+ controlnet = FluxMultiControlNetModel(controlnet)
224
+
225
+ self.register_modules(
226
+ vae=vae,
227
+ text_encoder=text_encoder,
228
+ text_encoder_2=text_encoder_2,
229
+ tokenizer=tokenizer,
230
+ tokenizer_2=tokenizer_2,
231
+ transformer=transformer,
232
+ scheduler=scheduler,
233
+ controlnet=controlnet,
234
+ )
235
+ self.vae_scale_factor = (
236
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
237
+ )
238
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
239
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
240
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
241
+ self.tokenizer_max_length = (
242
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
243
+ )
244
+ self.default_sample_size = 128
245
+
246
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
247
+ def _get_t5_prompt_embeds(
248
+ self,
249
+ prompt: Union[str, List[str]] = None,
250
+ num_images_per_prompt: int = 1,
251
+ max_sequence_length: int = 512,
252
+ device: Optional[torch.device] = None,
253
+ dtype: Optional[torch.dtype] = None,
254
+ ):
255
+ device = device or self._execution_device
256
+ dtype = dtype or self.text_encoder.dtype
257
+
258
+ prompt = [prompt] if isinstance(prompt, str) else prompt
259
+ batch_size = len(prompt)
260
+
261
+ if isinstance(self, TextualInversionLoaderMixin):
262
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
263
+
264
+ text_inputs = self.tokenizer_2(
265
+ prompt,
266
+ padding="max_length",
267
+ max_length=max_sequence_length,
268
+ truncation=True,
269
+ return_length=False,
270
+ return_overflowing_tokens=False,
271
+ return_tensors="pt",
272
+ )
273
+ text_input_ids = text_inputs.input_ids
274
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
275
+
276
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
277
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
278
+ logger.warning(
279
+ "The following part of your input was truncated because `max_sequence_length` is set to "
280
+ f" {max_sequence_length} tokens: {removed_text}"
281
+ )
282
+
283
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
284
+
285
+ dtype = self.text_encoder_2.dtype
286
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
287
+
288
+ _, seq_len, _ = prompt_embeds.shape
289
+
290
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
291
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
292
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
293
+
294
+ return prompt_embeds
295
+
296
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
297
+ def _get_clip_prompt_embeds(
298
+ self,
299
+ prompt: Union[str, List[str]],
300
+ num_images_per_prompt: int = 1,
301
+ device: Optional[torch.device] = None,
302
+ ):
303
+ device = device or self._execution_device
304
+
305
+ prompt = [prompt] if isinstance(prompt, str) else prompt
306
+ batch_size = len(prompt)
307
+
308
+ if isinstance(self, TextualInversionLoaderMixin):
309
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
310
+
311
+ text_inputs = self.tokenizer(
312
+ prompt,
313
+ padding="max_length",
314
+ max_length=self.tokenizer_max_length,
315
+ truncation=True,
316
+ return_overflowing_tokens=False,
317
+ return_length=False,
318
+ return_tensors="pt",
319
+ )
320
+
321
+ text_input_ids = text_inputs.input_ids
322
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
323
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
324
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
325
+ logger.warning(
326
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
327
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
328
+ )
329
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
330
+
331
+ # Use pooled output of CLIPTextModel
332
+ prompt_embeds = prompt_embeds.pooler_output
333
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
334
+
335
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
336
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
337
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
338
+
339
+ return prompt_embeds
340
+
341
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
342
+ def encode_prompt(
343
+ self,
344
+ prompt: Union[str, List[str]],
345
+ prompt_2: Union[str, List[str]],
346
+ device: Optional[torch.device] = None,
347
+ num_images_per_prompt: int = 1,
348
+ prompt_embeds: Optional[torch.FloatTensor] = None,
349
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
350
+ max_sequence_length: int = 512,
351
+ lora_scale: Optional[float] = None,
352
+ ):
353
+ r"""
354
+
355
+ Args:
356
+ prompt (`str` or `List[str]`, *optional*):
357
+ prompt to be encoded
358
+ prompt_2 (`str` or `List[str]`, *optional*):
359
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
360
+ used in all text-encoders
361
+ device: (`torch.device`):
362
+ torch device
363
+ num_images_per_prompt (`int`):
364
+ number of images that should be generated per prompt
365
+ prompt_embeds (`torch.FloatTensor`, *optional*):
366
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
367
+ provided, text embeddings will be generated from `prompt` input argument.
368
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
369
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
370
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
371
+ lora_scale (`float`, *optional*):
372
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
373
+ """
374
+ device = device or self._execution_device
375
+
376
+ # set lora scale so that monkey patched LoRA
377
+ # function of text encoder can correctly access it
378
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
379
+ self._lora_scale = lora_scale
380
+
381
+ # dynamically adjust the LoRA scale
382
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
383
+ scale_lora_layers(self.text_encoder, lora_scale)
384
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
385
+ scale_lora_layers(self.text_encoder_2, lora_scale)
386
+
387
+ prompt = [prompt] if isinstance(prompt, str) else prompt
388
+
389
+ if prompt_embeds is None:
390
+ prompt_2 = prompt_2 or prompt
391
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
392
+
393
+ # We only use the pooled prompt output from the CLIPTextModel
394
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
395
+ prompt=prompt,
396
+ device=device,
397
+ num_images_per_prompt=num_images_per_prompt,
398
+ )
399
+ prompt_embeds = self._get_t5_prompt_embeds(
400
+ prompt=prompt_2,
401
+ num_images_per_prompt=num_images_per_prompt,
402
+ max_sequence_length=max_sequence_length,
403
+ device=device,
404
+ )
405
+
406
+ if self.text_encoder is not None:
407
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
408
+ # Retrieve the original scale by scaling back the LoRA layers
409
+ unscale_lora_layers(self.text_encoder, lora_scale)
410
+
411
+ if self.text_encoder_2 is not None:
412
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
413
+ # Retrieve the original scale by scaling back the LoRA layers
414
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
415
+
416
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
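+ # Flux uses 3-component rotary position ids; text tokens are all assigned zero ids (shape [seq_len, 3]), while image ids are built in _prepare_latent_image_ids.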
417
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
418
+
419
+ return prompt_embeds, pooled_prompt_embeds, text_ids
420
+
421
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
422
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
423
+ if isinstance(generator, list):
424
+ image_latents = [
425
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
426
+ for i in range(image.shape[0])
427
+ ]
428
+ image_latents = torch.cat(image_latents, dim=0)
429
+ else:
430
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
431
+
432
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
433
+
434
+ return image_latents
435
+
436
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
437
+ def get_timesteps(self, num_inference_steps, strength, device):
438
+ # get the original timestep using init_timestep
439
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
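+ # e.g. strength=0.6 with 28 steps keeps the final 17 steps of the schedule (t_start = 11); strength=1.0 keeps all of them.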
440
+
441
+ t_start = int(max(num_inference_steps - init_timestep, 0))
442
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
443
+ if hasattr(self.scheduler, "set_begin_index"):
444
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
445
+
446
+ return timesteps, num_inference_steps - t_start
447
+
448
+ def check_inputs(
449
+ self,
450
+ prompt,
451
+ prompt_2,
452
+ strength,
453
+ height,
454
+ width,
455
+ callback_on_step_end_tensor_inputs,
456
+ prompt_embeds=None,
457
+ pooled_prompt_embeds=None,
458
+ max_sequence_length=None,
459
+ ):
460
+ if strength < 0 or strength > 1:
461
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
462
+
463
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
464
+ logger.warning(
465
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
466
+ )
467
+
468
+ if callback_on_step_end_tensor_inputs is not None and not all(
469
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
470
+ ):
471
+ raise ValueError(
472
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
473
+ )
474
+
475
+ if prompt is not None and prompt_embeds is not None:
476
+ raise ValueError(
477
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
478
+ " only forward one of the two."
479
+ )
480
+ elif prompt_2 is not None and prompt_embeds is not None:
481
+ raise ValueError(
482
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
483
+ " only forward one of the two."
484
+ )
485
+ elif prompt is None and prompt_embeds is None:
486
+ raise ValueError(
487
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
488
+ )
489
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
490
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
491
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
492
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
493
+
494
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
495
+ raise ValueError(
496
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
497
+ )
498
+
499
+ if max_sequence_length is not None and max_sequence_length > 512:
500
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
501
+
502
+ @staticmethod
503
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
504
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
505
+ latent_image_ids = torch.zeros(height, width, 3)
506
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
507
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
508
+
509
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
510
+
511
+ latent_image_ids = latent_image_ids.reshape(
512
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
513
+ )
514
+
515
+ return latent_image_ids.to(device=device, dtype=dtype)
516
+
517
+ @staticmethod
518
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
519
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
520
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
521
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
522
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
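+ # (B, C, h, w) -> (B, (h//2)*(w//2), 4*C): every 2x2 spatial patch of the latent becomes one packed token.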
523
+
524
+ return latents
525
+
526
+ @staticmethod
527
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
528
+ def _unpack_latents(latents, height, width, vae_scale_factor):
529
+ batch_size, num_patches, channels = latents.shape
530
+
531
+ # VAE applies 8x compression on images but we must also account for packing which requires
532
+ # latent height and width to be divisible by 2.
533
+ height = 2 * (int(height) // (vae_scale_factor * 2))
534
+ width = 2 * (int(width) // (vae_scale_factor * 2))
535
+
536
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
537
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
538
+
539
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
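+ # Inverse of _pack_latents: each packed token is unfolded back into a 2x2 spatial patch of the (B, C, h, w) latent grid.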
540
+
541
+ return latents
542
+
543
+ # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
544
+ def prepare_latents(
545
+ self,
546
+ image,
547
+ timestep,
548
+ batch_size,
549
+ num_channels_latents,
550
+ height,
551
+ width,
552
+ dtype,
553
+ device,
554
+ generator,
555
+ latents=None,
556
+ ):
557
+ if isinstance(generator, list) and len(generator) != batch_size:
558
+ raise ValueError(
559
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
560
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
561
+ )
562
+
563
+ # VAE applies 8x compression on images but we must also account for packing which requires
564
+ # latent height and width to be divisible by 2.
565
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
566
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
567
+ shape = (batch_size, num_channels_latents, height, width)
568
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
569
+
570
+ if latents is not None:
571
+ return latents.to(device=device, dtype=dtype), latent_image_ids
572
+
573
+ image = image.to(device=device, dtype=dtype)
574
+ image_latents = self._encode_vae_image(image=image, generator=generator)
575
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
576
+ # expand init_latents for batch_size
577
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
578
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
579
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
580
+ raise ValueError(
581
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
582
+ )
583
+ else:
584
+ image_latents = torch.cat([image_latents], dim=0)
585
+
586
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
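+ # Flow-matching image-to-image: scale_noise blends the clean image latents with Gaussian noise at the chosen start timestep (the flow-matching analogue of add_noise).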
587
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
588
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
589
+ return latents, latent_image_ids
590
+
591
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
592
+ def prepare_image(
593
+ self,
594
+ image,
595
+ width,
596
+ height,
597
+ batch_size,
598
+ num_images_per_prompt,
599
+ device,
600
+ dtype,
601
+ do_classifier_free_guidance=False,
602
+ guess_mode=False,
603
+ ):
604
+ if isinstance(image, torch.Tensor):
605
+ pass
606
+ else:
607
+ image = self.image_processor.preprocess(image, height=height, width=width)
608
+
609
+ image_batch_size = image.shape[0]
610
+
611
+ if image_batch_size == 1:
612
+ repeat_by = batch_size
613
+ else:
614
+ # image batch size is the same as prompt batch size
615
+ repeat_by = num_images_per_prompt
616
+
617
+ image = image.repeat_interleave(repeat_by, dim=0)
618
+
619
+ image = image.to(device=device, dtype=dtype)
620
+
621
+ if do_classifier_free_guidance and not guess_mode:
622
+ image = torch.cat([image] * 2)
623
+
624
+ return image
625
+
626
+ @property
627
+ def guidance_scale(self):
628
+ return self._guidance_scale
629
+
630
+ @property
631
+ def joint_attention_kwargs(self):
632
+ return self._joint_attention_kwargs
633
+
634
+ @property
635
+ def num_timesteps(self):
636
+ return self._num_timesteps
637
+
638
+ @property
639
+ def interrupt(self):
640
+ return self._interrupt
641
+
642
+ @torch.no_grad()
643
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
644
+ def __call__(
645
+ self,
646
+ prompt: Union[str, List[str]] = None,
647
+ prompt_2: Optional[Union[str, List[str]]] = None,
648
+ image: PipelineImageInput = None,
649
+ control_image: PipelineImageInput = None,
650
+ height: Optional[int] = None,
651
+ width: Optional[int] = None,
652
+ strength: float = 0.6,
653
+ num_inference_steps: int = 28,
654
+ sigmas: Optional[List[float]] = None,
655
+ guidance_scale: float = 7.0,
656
+ control_guidance_start: Union[float, List[float]] = 0.0,
657
+ control_guidance_end: Union[float, List[float]] = 1.0,
658
+ control_mode: Optional[Union[int, List[int]]] = None,
659
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
660
+ num_images_per_prompt: Optional[int] = 1,
661
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
662
+ latents: Optional[torch.FloatTensor] = None,
663
+ prompt_embeds: Optional[torch.FloatTensor] = None,
664
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
665
+ output_type: Optional[str] = "pil",
666
+ return_dict: bool = True,
667
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
668
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
669
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
670
+ max_sequence_length: int = 512,
671
+ ):
672
+ """
673
+ Function invoked when calling the pipeline for generation.
674
+
675
+ Args:
676
+ prompt (`str` or `List[str]`, *optional*):
677
+ The prompt or prompts to guide the image generation.
678
+ prompt_2 (`str` or `List[str]`, *optional*):
679
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`.
680
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
681
+ The image(s) to modify with the pipeline.
682
+ control_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
683
+ The ControlNet input condition. Image to control the generation.
684
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
685
+ The height in pixels of the generated image.
686
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
687
+ The width in pixels of the generated image.
688
+ strength (`float`, *optional*, defaults to 0.6):
689
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
690
+ num_inference_steps (`int`, *optional*, defaults to 28):
691
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
692
+ expense of slower inference.
693
+ sigmas (`List[float]`, *optional*):
694
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
695
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
696
+ will be used.
697
+ guidance_scale (`float`, *optional*, defaults to 7.0):
698
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
699
+ control_mode (`int` or `List[int]`, *optional*):
700
+ The mode for the ControlNet. If multiple ControlNets are used, this should be a list.
701
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
702
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
703
+ to the residual in the original transformer.
704
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
705
+ The number of images to generate per prompt.
706
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
707
+ One or more [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to
708
+ make generation deterministic.
709
+ latents (`torch.FloatTensor`, *optional*):
710
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
711
+ generation. Can be used to tweak the same generation with different prompts.
712
+ prompt_embeds (`torch.FloatTensor`, *optional*):
713
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
714
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
715
+ Pre-generated pooled text embeddings.
716
+ output_type (`str`, *optional*, defaults to `"pil"`):
717
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
718
+ return_dict (`bool`, *optional*, defaults to `True`):
719
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
720
+ joint_attention_kwargs (`dict`, *optional*):
721
+ Additional keyword arguments to be passed to the joint attention mechanism.
722
+ callback_on_step_end (`Callable`, *optional*):
723
+ A function that is called at the end of each denoising step during inference.
724
+ callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
725
+ The list of tensor inputs for the `callback_on_step_end` function.
726
+ max_sequence_length (`int`, *optional*, defaults to 512):
727
+ Maximum sequence length to use with the `prompt` (input to the T5 text encoder).
728
+
729
+ Examples:
730
+
731
+ Returns:
732
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
733
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
734
+ images.
735
+ """
736
+ height = height or self.default_sample_size * self.vae_scale_factor
737
+ width = width or self.default_sample_size * self.vae_scale_factor
738
+
739
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
740
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
741
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
742
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
743
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
744
+ mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
745
+ control_guidance_start, control_guidance_end = (
746
+ mult * [control_guidance_start],
747
+ mult * [control_guidance_end],
748
+ )
749
+
750
+ self.check_inputs(
751
+ prompt,
752
+ prompt_2,
753
+ strength,
754
+ height,
755
+ width,
756
+ callback_on_step_end_tensor_inputs,
757
+ prompt_embeds=prompt_embeds,
758
+ pooled_prompt_embeds=pooled_prompt_embeds,
759
+ max_sequence_length=max_sequence_length,
760
+ )
761
+
762
+ self._guidance_scale = guidance_scale
763
+ self._joint_attention_kwargs = joint_attention_kwargs
764
+ self._interrupt = False
765
+
766
+ if prompt is not None and isinstance(prompt, str):
767
+ batch_size = 1
768
+ elif prompt is not None and isinstance(prompt, list):
769
+ batch_size = len(prompt)
770
+ else:
771
+ batch_size = prompt_embeds.shape[0]
772
+
773
+ device = self._execution_device
774
+ dtype = self.transformer.dtype
775
+
776
+ lora_scale = (
777
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
778
+ )
779
+ (
780
+ prompt_embeds,
781
+ pooled_prompt_embeds,
782
+ text_ids,
783
+ ) = self.encode_prompt(
784
+ prompt=prompt,
785
+ prompt_2=prompt_2,
786
+ prompt_embeds=prompt_embeds,
787
+ pooled_prompt_embeds=pooled_prompt_embeds,
788
+ device=device,
789
+ num_images_per_prompt=num_images_per_prompt,
790
+ max_sequence_length=max_sequence_length,
791
+ lora_scale=lora_scale,
792
+ )
793
+
794
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
795
+ init_image = init_image.to(dtype=torch.float32)
796
+
797
+ num_channels_latents = self.transformer.config.in_channels // 4
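+ # The transformer consumes packed tokens carrying 4x the VAE latent channels, so dividing in_channels by 4 recovers the per-pixel latent channel count.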
798
+
799
+ if isinstance(self.controlnet, FluxControlNetModel):
800
+ control_image = self.prepare_image(
801
+ image=control_image,
802
+ width=width,
803
+ height=height,
804
+ batch_size=batch_size * num_images_per_prompt,
805
+ num_images_per_prompt=num_images_per_prompt,
806
+ device=device,
807
+ dtype=self.vae.dtype,
808
+ )
809
+ height, width = control_image.shape[-2:]
810
+
811
+ control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
812
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
813
+
814
+ height_control_image, width_control_image = control_image.shape[2:]
815
+ control_image = self._pack_latents(
816
+ control_image,
817
+ batch_size * num_images_per_prompt,
818
+ num_channels_latents,
819
+ height_control_image,
820
+ width_control_image,
821
+ )
822
+
823
+ if control_mode is not None:
824
+ control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
825
+ control_mode = control_mode.reshape([-1, 1])
826
+
827
+ elif isinstance(self.controlnet, FluxMultiControlNetModel):
828
+ control_images = []
829
+
830
+ for control_image_ in control_image:
831
+ control_image_ = self.prepare_image(
832
+ image=control_image_,
833
+ width=width,
834
+ height=height,
835
+ batch_size=batch_size * num_images_per_prompt,
836
+ num_images_per_prompt=num_images_per_prompt,
837
+ device=device,
838
+ dtype=self.vae.dtype,
839
+ )
840
+ height, width = control_image_.shape[-2:]
841
+
842
+ control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
843
+ control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
844
+
845
+ height_control_image, width_control_image = control_image_.shape[2:]
846
+ control_image_ = self._pack_latents(
847
+ control_image_,
848
+ batch_size * num_images_per_prompt,
849
+ num_channels_latents,
850
+ height_control_image,
851
+ width_control_image,
852
+ )
853
+
854
+ control_images.append(control_image_)
855
+
856
+ control_image = control_images
857
+
858
+ control_mode_ = []
859
+ if isinstance(control_mode, list):
860
+ for cmode in control_mode:
861
+ if cmode is None:
862
+ control_mode_.append(-1)
863
+ else:
864
+ control_mode_.append(cmode)
865
+ control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
866
+ control_mode = control_mode.reshape([-1, 1])
867
+
868
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
869
+ mu = calculate_shift(
870
+ image_seq_len,
871
+ self.scheduler.config.base_image_seq_len,
872
+ self.scheduler.config.max_image_seq_len,
873
+ self.scheduler.config.base_shift,
874
+ self.scheduler.config.max_shift,
875
+ )
876
+
877
+ # sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
878
+ timesteps, num_inference_steps = retrieve_timesteps(
879
+ self.scheduler,
880
+ num_inference_steps,
881
+ device,
882
+ sigmas=sigmas,
883
+ mu=mu,
884
+ )
885
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
886
+
887
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
888
+ latents, latent_image_ids = self.prepare_latents(
889
+ init_image,
890
+ latent_timestep,
891
+ batch_size * num_images_per_prompt,
892
+ num_channels_latents,
893
+ height,
894
+ width,
895
+ prompt_embeds.dtype,
896
+ device,
897
+ generator,
898
+ latents,
899
+ )
900
+
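+ # Per-step keep mask: ControlNet residuals are applied (1.0) only inside the [control_guidance_start, control_guidance_end] window and zeroed (0.0) outside it.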
901
+ controlnet_keep = []
902
+ for i in range(len(timesteps)):
903
+ keeps = [
904
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
905
+ for s, e in zip(control_guidance_start, control_guidance_end)
906
+ ]
907
+ controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
908
+
909
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
910
+ self._num_timesteps = len(timesteps)
911
+
912
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
913
+ for i, t in enumerate(timesteps):
914
+ if self.interrupt:
915
+ continue
916
+
917
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
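+ # Scheduler timesteps live on a [0, 1000] scale; the ControlNet and transformer below receive timestep / 1000, i.e. values in [0, 1].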
918
+
919
+ if isinstance(self.controlnet, FluxMultiControlNetModel):
920
+ use_guidance = self.controlnet.nets[0].config.guidance_embeds
921
+ else:
922
+ use_guidance = self.controlnet.config.guidance_embeds
923
+
924
+ guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
925
+ guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
926
+
927
+ if isinstance(controlnet_keep[i], list):
928
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
929
+ else:
930
+ controlnet_cond_scale = controlnet_conditioning_scale
931
+ if isinstance(controlnet_cond_scale, list):
932
+ controlnet_cond_scale = controlnet_cond_scale[0]
933
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
934
+
935
+ controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
936
+ hidden_states=latents,
937
+ controlnet_cond=control_image,
938
+ controlnet_mode=control_mode,
939
+ conditioning_scale=cond_scale,
940
+ timestep=timestep / 1000,
941
+ guidance=guidance,
942
+ pooled_projections=pooled_prompt_embeds,
943
+ encoder_hidden_states=prompt_embeds,
944
+ txt_ids=text_ids,
945
+ img_ids=latent_image_ids,
946
+ joint_attention_kwargs=self.joint_attention_kwargs,
947
+ return_dict=False,
948
+ )
949
+
950
+ guidance = (
951
+ torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None
952
+ )
953
+ guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
954
+
955
+ noise_pred = self.transformer(
956
+ hidden_states=latents,
957
+ timestep=timestep / 1000,
958
+ guidance=guidance,
959
+ pooled_projections=pooled_prompt_embeds,
960
+ encoder_hidden_states=prompt_embeds,
961
+ controlnet_block_samples=controlnet_block_samples,
962
+ controlnet_single_block_samples=controlnet_single_block_samples,
963
+ txt_ids=text_ids,
964
+ img_ids=latent_image_ids,
965
+ joint_attention_kwargs=self.joint_attention_kwargs,
966
+ return_dict=False,
967
+ )[0]
968
+
969
+ latents_dtype = latents.dtype
970
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
971
+
972
+ if latents.dtype != latents_dtype:
973
+ if torch.backends.mps.is_available():
974
+ latents = latents.to(latents_dtype)
975
+
976
+ if callback_on_step_end is not None:
977
+ callback_kwargs = {}
978
+ for k in callback_on_step_end_tensor_inputs:
979
+ callback_kwargs[k] = locals()[k]
980
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
981
+
982
+ latents = callback_outputs.pop("latents", latents)
983
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
984
+
985
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
986
+ progress_bar.update()
987
+
988
+ if XLA_AVAILABLE:
989
+ xm.mark_step()
990
+
991
+ if output_type == "latent":
992
+ image = latents
993
+ else:
994
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
995
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
996
+ image = self.vae.decode(latents, return_dict=False)[0]
997
+ image = self.image_processor.postprocess(image, output_type=output_type)
998
+
999
+ self.maybe_free_model_hooks()
1000
+
1001
+ if not return_dict:
1002
+ return (image,)
1003
+
1004
+ return FluxPipelineOutput(images=image)
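A minimal usage sketch for the custom ControlNet image-to-image pipeline above (the class name, import path and checkpoints are illustrative assumptions, not part of this commit):
```py
import torch
from diffusers import FluxControlNetModel
from diffusers.utils import load_image

# Assumed import path / class name for the pipeline defined in this file.
from pipeline.custom_pipelines.pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline

# Any Flux-compatible ControlNet checkpoint can be used; this union ControlNet is only an example.
controlnet = FluxControlNetModel.from_pretrained(
    "InstantX/FLUX.1-dev-Controlnet-Union", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

init_image = load_image(
    "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
).resize((1024, 1024))
control_image = init_image  # in practice, a canny/depth/pose map matching the chosen control_mode

image = pipe(
    prompt="a fantasy landscape, detailed, 8k",
    image=init_image,
    control_image=control_image,
    control_mode=0,  # only meaningful for union-style ControlNets
    strength=0.7,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("flux_controlnet_img2img.png")
```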
pipeline/custom_pipelines/pipeline_flux_img2img.py ADDED
@@ -0,0 +1,862 @@
1
+
2
+ # copied from diffusers/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
3
+
4
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import inspect
19
+ from typing import Any, Callable, Dict, List, Optional, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
24
+
25
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
26
+ from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
27
+ from diffusers.models.autoencoders import AutoencoderKL
28
+ from diffusers.models.transformers import FluxTransformer2DModel
29
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
30
+ from diffusers.utils import (
31
+ USE_PEFT_BACKEND,
32
+ is_torch_xla_available,
33
+ logging,
34
+ replace_example_docstring,
35
+ scale_lora_layers,
36
+ unscale_lora_layers,
37
+ )
38
+ from diffusers.utils.torch_utils import randn_tensor
39
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
40
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
41
+
42
+
43
+ if is_torch_xla_available():
44
+ import torch_xla.core.xla_model as xm
45
+
46
+ XLA_AVAILABLE = True
47
+ else:
48
+ XLA_AVAILABLE = False
49
+
50
+
51
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
+
53
+ EXAMPLE_DOC_STRING = """
54
+ Examples:
55
+ ```py
56
+ >>> import torch
57
+
58
+ >>> from diffusers import FluxImg2ImgPipeline
59
+ >>> from diffusers.utils import load_image
60
+
61
+ >>> device = "cuda"
62
+ >>> pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
63
+ >>> pipe = pipe.to(device)
64
+
65
+ >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
66
+ >>> init_image = load_image(url).resize((1024, 1024))
67
+
68
+ >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
69
+
70
+ >>> images = pipe(
71
+ ... prompt=prompt, image=init_image, num_inference_steps=4, strength=0.95, guidance_scale=0.0
72
+ ... ).images[0]
73
+ ```
74
+ """
75
+
76
+
77
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
78
+ def calculate_shift(
79
+ image_seq_len,
80
+ base_seq_len: int = 256,
81
+ max_seq_len: int = 4096,
82
+ base_shift: float = 0.5,
83
+ max_shift: float = 1.16,
84
+ ):
85
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
86
+ b = base_shift - m * base_seq_len
87
+ mu = image_seq_len * m + b
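+ # Linear interpolation of the timestep shift: with the defaults, mu goes from base_shift=0.5 at 256 image tokens up to max_shift=1.16 at 4096 tokens (roughly a 1024x1024 image).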
88
+ return mu
89
+
90
+
91
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
92
+ def retrieve_latents(
93
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
94
+ ):
95
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
96
+ return encoder_output.latent_dist.sample(generator)
97
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
98
+ return encoder_output.latent_dist.mode()
99
+ elif hasattr(encoder_output, "latents"):
100
+ return encoder_output.latents
101
+ else:
102
+ raise AttributeError("Could not access latents of provided encoder_output")
103
+
104
+
105
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
106
+ def retrieve_timesteps(
107
+ scheduler,
108
+ num_inference_steps: Optional[int] = None,
109
+ device: Optional[Union[str, torch.device]] = None,
110
+ timesteps: Optional[List[int]] = None,
111
+ sigmas: Optional[List[float]] = None,
112
+ **kwargs,
113
+ ):
114
+ r"""
115
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
116
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
117
+
118
+ Args:
119
+ scheduler (`SchedulerMixin`):
120
+ The scheduler to get timesteps from.
121
+ num_inference_steps (`int`):
122
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
123
+ must be `None`.
124
+ device (`str` or `torch.device`, *optional*):
125
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
126
+ timesteps (`List[int]`, *optional*):
127
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
128
+ `num_inference_steps` and `sigmas` must be `None`.
129
+ sigmas (`List[float]`, *optional*):
130
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
131
+ `num_inference_steps` and `timesteps` must be `None`.
132
+
133
+ Returns:
134
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
135
+ second element is the number of inference steps.
136
+ """
137
+ if timesteps is not None and sigmas is not None:
138
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
139
+ if timesteps is not None:
140
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
141
+ if not accepts_timesteps:
142
+ raise ValueError(
143
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
144
+ f" timestep schedules. Please check whether you are using the correct scheduler."
145
+ )
146
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
147
+ timesteps = scheduler.timesteps
148
+ num_inference_steps = len(timesteps)
149
+ else:
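+ # Deviation from the upstream diffusers helper: schedulers that do not accept sigmas fall back to their default spacing below instead of raising.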
150
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
151
+ if accept_sigmas:
152
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
153
+ # if not accept_sigmas:
154
+ # raise ValueError(
155
+ # f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
156
+ # f" sigmas schedules. Please check whether you are using the correct scheduler."
157
+ # )
158
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
159
+ timesteps = scheduler.timesteps
160
+ num_inference_steps = len(timesteps)
161
+ else:
162
+ scheduler.set_timesteps(num_inference_steps, device=device)  # (**kwargs not forwarded in this branch)
163
+ timesteps = scheduler.timesteps
164
+
165
+ return timesteps, num_inference_steps
166
+
167
+
168
+ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
169
+ r"""
170
+ The Flux pipeline for image-to-image generation.
171
+
172
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
173
+
174
+ Args:
175
+ transformer ([`FluxTransformer2DModel`]):
176
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
177
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
178
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
179
+ vae ([`AutoencoderKL`]):
180
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
181
+ text_encoder ([`CLIPTextModel`]):
182
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
183
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
184
+ text_encoder_2 ([`T5EncoderModel`]):
185
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
186
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
187
+ tokenizer (`CLIPTokenizer`):
188
+ Tokenizer of class
189
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
190
+ tokenizer_2 (`T5TokenizerFast`):
191
+ Second Tokenizer of class
192
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
193
+ """
194
+
195
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
196
+ _optional_components = []
197
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
198
+
199
+ def __init__(
200
+ self,
201
+ scheduler: FlowMatchEulerDiscreteScheduler,
202
+ vae: AutoencoderKL,
203
+ text_encoder: CLIPTextModel,
204
+ tokenizer: CLIPTokenizer,
205
+ text_encoder_2: T5EncoderModel,
206
+ tokenizer_2: T5TokenizerFast,
207
+ transformer: FluxTransformer2DModel,
208
+ ):
209
+ super().__init__()
210
+
211
+ self.register_modules(
212
+ vae=vae,
213
+ text_encoder=text_encoder,
214
+ text_encoder_2=text_encoder_2,
215
+ tokenizer=tokenizer,
216
+ tokenizer_2=tokenizer_2,
217
+ transformer=transformer,
218
+ scheduler=scheduler,
219
+ )
220
+ self.vae_scale_factor = (
221
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
222
+ )
223
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height have to be divisible
224
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
225
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
226
+ self.tokenizer_max_length = (
227
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
228
+ )
229
+ self.default_sample_size = 128
230
+
231
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
232
+ def _get_t5_prompt_embeds(
233
+ self,
234
+ prompt: Union[str, List[str]] = None,
235
+ num_images_per_prompt: int = 1,
236
+ max_sequence_length: int = 512,
237
+ device: Optional[torch.device] = None,
238
+ dtype: Optional[torch.dtype] = None,
239
+ ):
240
+ device = device or self._execution_device
241
+ dtype = dtype or self.text_encoder.dtype
242
+
243
+ prompt = [prompt] if isinstance(prompt, str) else prompt
244
+ batch_size = len(prompt)
245
+
246
+ if isinstance(self, TextualInversionLoaderMixin):
247
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
248
+
249
+ text_inputs = self.tokenizer_2(
250
+ prompt,
251
+ padding="max_length",
252
+ max_length=max_sequence_length,
253
+ truncation=True,
254
+ return_length=False,
255
+ return_overflowing_tokens=False,
256
+ return_tensors="pt",
257
+ )
258
+ text_input_ids = text_inputs.input_ids
259
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
260
+
261
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
262
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
263
+ logger.warning(
264
+ "The following part of your input was truncated because `max_sequence_length` is set to "
265
+ f" {max_sequence_length} tokens: {removed_text}"
266
+ )
267
+
268
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
269
+
270
+ dtype = self.text_encoder_2.dtype
271
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
272
+
273
+ _, seq_len, _ = prompt_embeds.shape
274
+
275
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
276
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
277
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
278
+
279
+ return prompt_embeds
280
+
281
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
282
+ def _get_clip_prompt_embeds(
283
+ self,
284
+ prompt: Union[str, List[str]],
285
+ num_images_per_prompt: int = 1,
286
+ device: Optional[torch.device] = None,
287
+ ):
288
+ device = device or self._execution_device
289
+
290
+ prompt = [prompt] if isinstance(prompt, str) else prompt
291
+ batch_size = len(prompt)
292
+
293
+ if isinstance(self, TextualInversionLoaderMixin):
294
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
295
+
296
+ text_inputs = self.tokenizer(
297
+ prompt,
298
+ padding="max_length",
299
+ max_length=self.tokenizer_max_length,
300
+ truncation=True,
301
+ return_overflowing_tokens=False,
302
+ return_length=False,
303
+ return_tensors="pt",
304
+ )
305
+
306
+ text_input_ids = text_inputs.input_ids
307
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
308
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
309
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
310
+ logger.warning(
311
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
312
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
313
+ )
314
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
315
+
316
+ # Use pooled output of CLIPTextModel
317
+ prompt_embeds = prompt_embeds.pooler_output
318
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
319
+
320
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
321
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
322
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
323
+
324
+ return prompt_embeds
325
+
326
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
327
+ def encode_prompt(
328
+ self,
329
+ prompt: Union[str, List[str]],
330
+ prompt_2: Union[str, List[str]],
331
+ device: Optional[torch.device] = None,
332
+ num_images_per_prompt: int = 1,
333
+ prompt_embeds: Optional[torch.FloatTensor] = None,
334
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
335
+ max_sequence_length: int = 512,
336
+ lora_scale: Optional[float] = None,
337
+ ):
338
+ r"""
339
+
340
+ Args:
341
+ prompt (`str` or `List[str]`, *optional*):
342
+ prompt to be encoded
343
+ prompt_2 (`str` or `List[str]`, *optional*):
344
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
345
+ used in all text-encoders
346
+ device: (`torch.device`):
347
+ torch device
348
+ num_images_per_prompt (`int`):
349
+ number of images that should be generated per prompt
350
+ prompt_embeds (`torch.FloatTensor`, *optional*):
351
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
352
+ provided, text embeddings will be generated from `prompt` input argument.
353
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
354
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
355
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
356
+ lora_scale (`float`, *optional*):
357
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
358
+ """
359
+ device = device or self._execution_device
360
+
361
+ # set lora scale so that monkey patched LoRA
362
+ # function of text encoder can correctly access it
363
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
364
+ self._lora_scale = lora_scale
365
+
366
+ # dynamically adjust the LoRA scale
367
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
368
+ scale_lora_layers(self.text_encoder, lora_scale)
369
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
370
+ scale_lora_layers(self.text_encoder_2, lora_scale)
371
+
372
+ prompt = [prompt] if isinstance(prompt, str) else prompt
373
+
374
+ if prompt_embeds is None:
375
+ prompt_2 = prompt_2 or prompt
376
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
377
+
378
+ # We only use the pooled prompt output from the CLIPTextModel
379
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
380
+ prompt=prompt,
381
+ device=device,
382
+ num_images_per_prompt=num_images_per_prompt,
383
+ )
384
+ prompt_embeds = self._get_t5_prompt_embeds(
385
+ prompt=prompt_2,
386
+ num_images_per_prompt=num_images_per_prompt,
387
+ max_sequence_length=max_sequence_length,
388
+ device=device,
389
+ )
390
+
391
+ if self.text_encoder is not None:
392
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
393
+ # Retrieve the original scale by scaling back the LoRA layers
394
+ unscale_lora_layers(self.text_encoder, lora_scale)
395
+
396
+ if self.text_encoder_2 is not None:
397
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
398
+ # Retrieve the original scale by scaling back the LoRA layers
399
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
400
+
401
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
402
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
403
+
404
+ return prompt_embeds, pooled_prompt_embeds, text_ids
405
+
406
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
407
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
408
+ if isinstance(generator, list):
409
+ image_latents = [
410
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
411
+ for i in range(image.shape[0])
412
+ ]
413
+ image_latents = torch.cat(image_latents, dim=0)
414
+ else:
415
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
416
+
417
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
418
+
419
+ return image_latents
420
+
421
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
422
+ def get_timesteps(self, num_inference_steps, strength, device):
423
+ # get the original timestep using init_timestep
424
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
425
+
426
+ t_start = int(max(num_inference_steps - init_timestep, 0))
427
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
428
+ if hasattr(self.scheduler, "set_begin_index"):
429
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
430
+
431
+ return timesteps, num_inference_steps - t_start
432
+
433
+ def check_inputs(
434
+ self,
435
+ prompt,
436
+ prompt_2,
437
+ strength,
438
+ height,
439
+ width,
440
+ prompt_embeds=None,
441
+ pooled_prompt_embeds=None,
442
+ callback_on_step_end_tensor_inputs=None,
443
+ max_sequence_length=None,
444
+ ):
445
+ if strength < 0 or strength > 1:
446
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
447
+
448
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
449
+ logger.warning(
450
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
451
+ )
452
+
453
+ if callback_on_step_end_tensor_inputs is not None and not all(
454
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
455
+ ):
456
+ raise ValueError(
457
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
458
+ )
459
+
460
+ if prompt is not None and prompt_embeds is not None:
461
+ raise ValueError(
462
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
463
+ " only forward one of the two."
464
+ )
465
+ elif prompt_2 is not None and prompt_embeds is not None:
466
+ raise ValueError(
467
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
468
+ " only forward one of the two."
469
+ )
470
+ elif prompt is None and prompt_embeds is None:
471
+ raise ValueError(
472
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
473
+ )
474
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
475
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
476
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
477
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
478
+
479
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
480
+ raise ValueError(
481
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
482
+ )
483
+
484
+ if max_sequence_length is not None and max_sequence_length > 512:
485
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
486
+
487
+ @staticmethod
488
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
489
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
490
+ latent_image_ids = torch.zeros(height, width, 3)
491
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
492
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
493
+
494
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
495
+
496
+ latent_image_ids = latent_image_ids.reshape(
497
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
498
+ )
499
+
500
+ return latent_image_ids.to(device=device, dtype=dtype)
501
+
502
+ @staticmethod
503
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
504
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
505
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
506
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
507
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
508
+
509
+ return latents
510
+
511
+ @staticmethod
512
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
513
+ def _unpack_latents(latents, height, width, vae_scale_factor):
514
+ batch_size, num_patches, channels = latents.shape
515
+
516
+ # VAE applies 8x compression on images but we must also account for packing which requires
517
+ # latent height and width to be divisible by 2.
518
+ height = 2 * (int(height) // (vae_scale_factor * 2))
519
+ width = 2 * (int(width) // (vae_scale_factor * 2))
520
+
521
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
522
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
523
+
524
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
525
+
526
+ return latents
527
+
528
+ def prepare_latents(
529
+ self,
530
+ image,
531
+ timestep,
532
+ batch_size,
533
+ num_channels_latents,
534
+ height,
535
+ width,
536
+ dtype,
537
+ device,
538
+ generator,
539
+ latents=None,
540
+ ):
541
+ if isinstance(generator, list) and len(generator) != batch_size:
542
+ raise ValueError(
543
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
544
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
545
+ )
546
+
547
+ # VAE applies 8x compression on images but we must also account for packing which requires
548
+ # latent height and width to be divisible by 2.
549
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
550
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
551
+ shape = (batch_size, num_channels_latents, height, width)
552
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
553
+
554
+ if latents is not None:
555
+ return latents.to(device=device, dtype=dtype), latent_image_ids
556
+
557
+ image = image.to(device=device, dtype=dtype)
558
+ image_latents = self._encode_vae_image(image=image, generator=generator)
559
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
560
+ # expand init_latents for batch_size
561
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
562
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
563
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
564
+ raise ValueError(
565
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
566
+ )
567
+ else:
568
+ image_latents = torch.cat([image_latents], dim=0)
569
+
570
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
571
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
572
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
573
+ return latents, latent_image_ids
574
+
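The `scale_noise` call is what realizes `strength`: the encoded image is blended with fresh noise at the starting sigma. A hedged sketch of that interpolation, assuming the scheduler uses the standard flow-matching linear blend (the helper name and signature below are illustrative, not taken from this file):

```py
import torch

def scale_noise_sketch(sample: torch.Tensor, sigma: float, noise: torch.Tensor) -> torch.Tensor:
    # sigma == 1.0: pure noise, so strength 1 effectively ignores the init image;
    # sigma == 0.0: the encoded image passes through unchanged.
    return sigma * noise + (1.0 - sigma) * sample

image_latents = torch.randn(1, 16, 128, 128)
noise = torch.randn_like(image_latents)
noisy = scale_noise_sketch(image_latents, sigma=0.6, noise=noise)
print(noisy.shape)  # torch.Size([1, 16, 128, 128])
```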
575
+ @property
576
+ def guidance_scale(self):
577
+ return self._guidance_scale
578
+
579
+ @property
580
+ def joint_attention_kwargs(self):
581
+ return self._joint_attention_kwargs
582
+
583
+ @property
584
+ def num_timesteps(self):
585
+ return self._num_timesteps
586
+
587
+ @property
588
+ def interrupt(self):
589
+ return self._interrupt
590
+
591
+ @torch.no_grad()
592
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
593
+ def __call__(
594
+ self,
595
+ prompt: Union[str, List[str]] = None,
596
+ prompt_2: Optional[Union[str, List[str]]] = None,
597
+ image: PipelineImageInput = None,
598
+ height: Optional[int] = None,
599
+ width: Optional[int] = None,
600
+ strength: float = 0.6,
601
+ num_inference_steps: int = 28,
602
+ sigmas: Optional[List[float]] = None,
603
+ guidance_scale: float = 7.0,
604
+ num_images_per_prompt: Optional[int] = 1,
605
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
606
+ latents: Optional[torch.FloatTensor] = None,
607
+ prompt_embeds: Optional[torch.FloatTensor] = None,
608
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
609
+ output_type: Optional[str] = "pil",
610
+ return_dict: bool = True,
611
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
612
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
613
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
614
+ max_sequence_length: int = 512,
615
+ ):
616
+ r"""
617
+ Function invoked when calling the pipeline for generation.
618
+
619
+ Args:
620
+ prompt (`str` or `List[str]`, *optional*):
621
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
622
+ instead.
623
+ prompt_2 (`str` or `List[str]`, *optional*):
624
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
625
+ will be used instead.
626
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
627
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
628
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
629
+ of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
630
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
631
+ latents as `image`, but if passing latents directly it is not encoded again.
632
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
633
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
634
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
635
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
636
+ strength (`float`, *optional*, defaults to 0.6):
637
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
638
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
639
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
640
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
641
+ essentially ignores `image`.
642
+ num_inference_steps (`int`, *optional*, defaults to 28):
643
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
644
+ expense of slower inference.
645
+ sigmas (`List[float]`, *optional*):
646
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
647
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
648
+ will be used.
649
+ guidance_scale (`float`, *optional*, defaults to 7.0):
650
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
651
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
652
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
653
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
654
+ usually at the expense of lower image quality.
655
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
656
+ The number of images to generate per prompt.
657
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
658
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
659
+ to make generation deterministic.
660
+ latents (`torch.FloatTensor`, *optional*):
661
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
662
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
663
+ tensor will be generated by sampling using the supplied random `generator`.
664
+ prompt_embeds (`torch.FloatTensor`, *optional*):
665
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
666
+ provided, text embeddings will be generated from `prompt` input argument.
667
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
668
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
669
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
670
+ output_type (`str`, *optional*, defaults to `"pil"`):
671
+ The output format of the generated image. Choose between
672
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
673
+ return_dict (`bool`, *optional*, defaults to `True`):
674
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
675
+ joint_attention_kwargs (`dict`, *optional*):
676
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
677
+ `self.processor` in
678
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
679
+ callback_on_step_end (`Callable`, *optional*):
680
+ A function that is called at the end of each denoising step during inference. The function is called
681
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
682
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
683
+ `callback_on_step_end_tensor_inputs`.
684
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
685
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
686
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
687
+ `._callback_tensor_inputs` attribute of your pipeline class.
688
+ max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`.
689
+
690
+ Examples:
691
+
692
+ Returns:
693
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
694
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
695
+ images.
696
+ """
697
+
698
+ height = height or self.default_sample_size * self.vae_scale_factor
699
+ width = width or self.default_sample_size * self.vae_scale_factor
700
+
701
+ # 1. Check inputs. Raise error if not correct
702
+ self.check_inputs(
703
+ prompt,
704
+ prompt_2,
705
+ strength,
706
+ height,
707
+ width,
708
+ prompt_embeds=prompt_embeds,
709
+ pooled_prompt_embeds=pooled_prompt_embeds,
710
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
711
+ max_sequence_length=max_sequence_length,
712
+ )
713
+
714
+ self._guidance_scale = guidance_scale
715
+ self._joint_attention_kwargs = joint_attention_kwargs
716
+ self._interrupt = False
717
+
718
+ # 2. Preprocess image
719
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
720
+ init_image = init_image.to(dtype=torch.float32)
721
+
722
+ # 3. Define call parameters
723
+ if prompt is not None and isinstance(prompt, str):
724
+ batch_size = 1
725
+ elif prompt is not None and isinstance(prompt, list):
726
+ batch_size = len(prompt)
727
+ else:
728
+ batch_size = prompt_embeds.shape[0]
729
+
730
+ device = self._execution_device
731
+
732
+ lora_scale = (
733
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
734
+ )
735
+ (
736
+ prompt_embeds,
737
+ pooled_prompt_embeds,
738
+ text_ids,
739
+ ) = self.encode_prompt(
740
+ prompt=prompt,
741
+ prompt_2=prompt_2,
742
+ prompt_embeds=prompt_embeds,
743
+ pooled_prompt_embeds=pooled_prompt_embeds,
744
+ device=device,
745
+ num_images_per_prompt=num_images_per_prompt,
746
+ max_sequence_length=max_sequence_length,
747
+ lora_scale=lora_scale,
748
+ )
749
+
750
+ # 4.Prepare timesteps
751
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
752
+ mu = calculate_shift(
753
+ image_seq_len,
754
+ self.scheduler.config.base_image_seq_len,
755
+ self.scheduler.config.max_image_seq_len,
756
+ self.scheduler.config.base_shift,
757
+ self.scheduler.config.max_shift,
758
+ )
759
+ # sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
760
+ timesteps, num_inference_steps = retrieve_timesteps(
761
+ self.scheduler,
762
+ num_inference_steps,
763
+ device,
764
+ sigmas=sigmas,
765
+ mu=mu,
766
+ )
767
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
768
+
769
+ if num_inference_steps < 1:
770
+ raise ValueError(
771
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
772
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
773
+ )
774
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
775
+
776
+ # 5. Prepare latent variables
777
+ num_channels_latents = self.transformer.config.in_channels // 4
778
+
779
+ latents, latent_image_ids = self.prepare_latents(
780
+ init_image,
781
+ latent_timestep,
782
+ batch_size * num_images_per_prompt,
783
+ num_channels_latents,
784
+ height,
785
+ width,
786
+ prompt_embeds.dtype,
787
+ device,
788
+ generator,
789
+ latents,
790
+ )
791
+
792
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
793
+ self._num_timesteps = len(timesteps)
794
+
795
+ # handle guidance
796
+ if self.transformer.config.guidance_embeds:
797
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
798
+ guidance = guidance.expand(latents.shape[0])
799
+ else:
800
+ guidance = None
801
+
802
+ # 6. Denoising loop
803
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
804
+ for i, t in enumerate(timesteps):
805
+ if self.interrupt:
806
+ continue
807
+
808
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
809
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
810
+ noise_pred = self.transformer(
811
+ hidden_states=latents,
812
+ timestep=timestep / 1000,
813
+ guidance=guidance,
814
+ pooled_projections=pooled_prompt_embeds,
815
+ encoder_hidden_states=prompt_embeds,
816
+ txt_ids=text_ids,
817
+ img_ids=latent_image_ids,
818
+ joint_attention_kwargs=self.joint_attention_kwargs,
819
+ return_dict=False,
820
+ )[0]
821
+
822
+ # compute the previous noisy sample x_t -> x_t-1
823
+ latents_dtype = latents.dtype
824
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
825
+
826
+ if latents.dtype != latents_dtype:
827
+ if torch.backends.mps.is_available():
828
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
829
+ latents = latents.to(latents_dtype)
830
+
831
+ if callback_on_step_end is not None:
832
+ callback_kwargs = {}
833
+ for k in callback_on_step_end_tensor_inputs:
834
+ callback_kwargs[k] = locals()[k]
835
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
836
+
837
+ latents = callback_outputs.pop("latents", latents)
838
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
839
+
840
+ # call the callback, if provided
841
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
842
+ progress_bar.update()
843
+
844
+ if XLA_AVAILABLE:
845
+ xm.mark_step()
846
+
847
+ if output_type == "latent":
848
+ image = latents
849
+
850
+ else:
851
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
852
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
853
+ image = self.vae.decode(latents, return_dict=False)[0]
854
+ image = self.image_processor.postprocess(image, output_type=output_type)
855
+
856
+ # Offload all models
857
+ self.maybe_free_model_hooks()
858
+
859
+ if not return_dict:
860
+ return (image,)
861
+
862
+ return FluxPipelineOutput(images=image)
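A minimal usage sketch for this customized img2img pipeline (the module path `pipeline.custom_pipelines` matches the import used later in `kiss3d_wrapper.py`; the checkpoint id, image URL, and parameter values are illustrative assumptions):

```py
import torch
from diffusers.utils import load_image
from pipeline.custom_pipelines import FluxImg2ImgPipeline

pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

init_image = load_image(
    "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png"
)

result = pipe(
    prompt="A grid of 2x4 multi-view image, elevation 5. White background.",
    image=init_image,
    strength=0.6,              # how far to re-noise the init image
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
result.save("flux_img2img.png")
```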
pipeline/custom_pipelines/pipeline_flux_prior_redux.py ADDED
@@ -0,0 +1,500 @@
1
+ # copied from diffusers/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
2
+
3
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ from typing import List, Optional, Union
19
+
20
+ import torch
21
+ from PIL import Image
22
+ from transformers import (
23
+ CLIPTextModel,
24
+ CLIPTokenizer,
25
+ SiglipImageProcessor,
26
+ SiglipVisionModel,
27
+ T5EncoderModel,
28
+ T5TokenizerFast,
29
+ )
30
+
31
+ from diffusers.image_processor import PipelineImageInput
32
+ from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
33
+ from diffusers.utils import (
34
+ USE_PEFT_BACKEND,
35
+ is_torch_xla_available,
36
+ logging,
37
+ replace_example_docstring,
38
+ scale_lora_layers,
39
+ unscale_lora_layers,
40
+ )
41
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
42
+ from diffusers.pipelines.flux.modeling_flux import ReduxImageEncoder
43
+ from diffusers.pipelines.flux.pipeline_output import FluxPriorReduxPipelineOutput
44
+
45
+
46
+ if is_torch_xla_available():
47
+ XLA_AVAILABLE = True
48
+ else:
49
+ XLA_AVAILABLE = False
50
+
51
+
52
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
53
+
54
+ EXAMPLE_DOC_STRING = """
55
+ Examples:
56
+ ```py
57
+ >>> import torch
58
+ >>> from diffusers import FluxPriorReduxPipeline, FluxPipeline
59
+ >>> from diffusers.utils import load_image
60
+
61
+ >>> device = "cuda"
62
+ >>> dtype = torch.bfloat16
63
+
64
+ >>> repo_redux = "black-forest-labs/FLUX.1-Redux-dev"
65
+ >>> repo_base = "black-forest-labs/FLUX.1-dev"
66
+ >>> pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(repo_redux, torch_dtype=dtype).to(device)
67
+ >>> pipe = FluxPipeline.from_pretrained(
68
+ ... repo_base, text_encoder=None, text_encoder_2=None, torch_dtype=torch.bfloat16
69
+ ... ).to(device)
70
+
71
+ >>> image = load_image(
72
+ ... "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png"
73
+ ... )
74
+ >>> pipe_prior_output = pipe_prior_redux(image)
75
+ >>> images = pipe(
76
+ ... guidance_scale=2.5,
77
+ ... num_inference_steps=50,
78
+ ... generator=torch.Generator("cpu").manual_seed(0),
79
+ ... **pipe_prior_output,
80
+ ... ).images
81
+ >>> images[0].save("flux-redux.png")
82
+ ```
83
+ """
84
+
85
+
86
+ class FluxPriorReduxPipeline(DiffusionPipeline):
87
+ r"""
88
+ The Flux Redux pipeline for image-to-image generation.
89
+
90
+ Reference: https://blackforestlabs.ai/flux-1-tools/
91
+
92
+ Args:
93
+ image_encoder ([`SiglipVisionModel`]):
94
+ SIGLIP vision model to encode the input image.
95
+ feature_extractor ([`SiglipImageProcessor`]):
96
+ Image processor for preprocessing images for the SIGLIP model.
97
+ image_embedder ([`ReduxImageEncoder`]):
98
+ Redux image encoder to process the SIGLIP embeddings.
99
+ text_encoder ([`CLIPTextModel`], *optional*):
100
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
101
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
102
+ text_encoder_2 ([`T5EncoderModel`], *optional*):
103
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
104
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
105
+ tokenizer (`CLIPTokenizer`, *optional*):
106
+ Tokenizer of class
107
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
108
+ tokenizer_2 (`T5TokenizerFast`, *optional*):
109
+ Second Tokenizer of class
110
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
111
+ """
112
+
113
+ model_cpu_offload_seq = "image_encoder->image_embedder"
114
+ _optional_components = [
115
+ "text_encoder",
116
+ "tokenizer",
117
+ "text_encoder_2",
118
+ "tokenizer_2",
119
+ ]
120
+ _callback_tensor_inputs = []
121
+
122
+ def __init__(
123
+ self,
124
+ image_encoder: SiglipVisionModel,
125
+ feature_extractor: SiglipImageProcessor,
126
+ image_embedder: ReduxImageEncoder,
127
+ text_encoder: CLIPTextModel = None,
128
+ tokenizer: CLIPTokenizer = None,
129
+ text_encoder_2: T5EncoderModel = None,
130
+ tokenizer_2: T5TokenizerFast = None,
131
+ ):
132
+ super().__init__()
133
+
134
+ self.register_modules(
135
+ image_encoder=image_encoder,
136
+ feature_extractor=feature_extractor,
137
+ image_embedder=image_embedder,
138
+ text_encoder=text_encoder,
139
+ tokenizer=tokenizer,
140
+ text_encoder_2=text_encoder_2,
141
+ tokenizer_2=tokenizer_2,
142
+ )
143
+ self.tokenizer_max_length = (
144
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
145
+ )
146
+
147
+ def check_inputs(
148
+ self,
149
+ image,
150
+ prompt,
151
+ prompt_2,
152
+ prompt_embeds=None,
153
+ pooled_prompt_embeds=None,
154
+ prompt_embeds_scale=1.0,
155
+ pooled_prompt_embeds_scale=1.0,
156
+ ):
157
+ if prompt is not None and prompt_embeds is not None:
158
+ raise ValueError(
159
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
160
+ " only forward one of the two."
161
+ )
162
+ elif prompt_2 is not None and prompt_embeds is not None:
163
+ raise ValueError(
164
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
165
+ " only forward one of the two."
166
+ )
167
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
168
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
169
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
170
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
171
+ if prompt is not None and (isinstance(prompt, list) and isinstance(image, list) and len(prompt) != len(image)):
172
+ raise ValueError(
173
+ f"number of prompts must be equal to number of images, but {len(prompt)} prompts were provided and {len(image)} images"
174
+ )
175
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
176
+ raise ValueError(
177
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
178
+ )
179
+ if isinstance(prompt_embeds_scale, list) and (
180
+ isinstance(image, list) and len(prompt_embeds_scale) != len(image)
181
+ ):
182
+ raise ValueError(
183
+ f"number of weights must be equal to number of images, but {len(prompt_embeds_scale)} weights were provided and {len(image)} images"
184
+ )
185
+
186
+ def encode_image(self, image, device, num_images_per_prompt):
187
+ dtype = next(self.image_encoder.parameters()).dtype
188
+ image = self.feature_extractor.preprocess(
189
+ images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True
190
+ )
191
+ image = image.to(device=device, dtype=dtype)
192
+
193
+ image_enc_hidden_states = self.image_encoder(**image).last_hidden_state
194
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
195
+
196
+ return image_enc_hidden_states
197
+
198
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
199
+ def _get_t5_prompt_embeds(
200
+ self,
201
+ prompt: Union[str, List[str]] = None,
202
+ num_images_per_prompt: int = 1,
203
+ max_sequence_length: int = 512,
204
+ device: Optional[torch.device] = None,
205
+ dtype: Optional[torch.dtype] = None,
206
+ ):
207
+ device = device or self._execution_device
208
+ dtype = dtype or self.text_encoder.dtype
209
+
210
+ prompt = [prompt] if isinstance(prompt, str) else prompt
211
+ batch_size = len(prompt)
212
+
213
+ if isinstance(self, TextualInversionLoaderMixin):
214
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
215
+
216
+ text_inputs = self.tokenizer_2(
217
+ prompt,
218
+ padding="max_length",
219
+ max_length=max_sequence_length,
220
+ truncation=True,
221
+ return_length=False,
222
+ return_overflowing_tokens=False,
223
+ return_tensors="pt",
224
+ )
225
+ text_input_ids = text_inputs.input_ids
226
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
227
+
228
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
229
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
230
+ logger.warning(
231
+ "The following part of your input was truncated because `max_sequence_length` is set to "
232
+ f" {max_sequence_length} tokens: {removed_text}"
233
+ )
234
+
235
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
236
+
237
+ dtype = self.text_encoder_2.dtype
238
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
239
+
240
+ _, seq_len, _ = prompt_embeds.shape
241
+
242
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
243
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
244
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
245
+
246
+ return prompt_embeds
247
+
248
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
249
+ def _get_clip_prompt_embeds(
250
+ self,
251
+ prompt: Union[str, List[str]],
252
+ num_images_per_prompt: int = 1,
253
+ device: Optional[torch.device] = None,
254
+ ):
255
+ device = device or self._execution_device
256
+
257
+ prompt = [prompt] if isinstance(prompt, str) else prompt
258
+ batch_size = len(prompt)
259
+
260
+ if isinstance(self, TextualInversionLoaderMixin):
261
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
262
+
263
+ text_inputs = self.tokenizer(
264
+ prompt,
265
+ padding="max_length",
266
+ max_length=self.tokenizer_max_length,
267
+ truncation=True,
268
+ return_overflowing_tokens=False,
269
+ return_length=False,
270
+ return_tensors="pt",
271
+ )
272
+
273
+ text_input_ids = text_inputs.input_ids
274
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
275
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
276
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
277
+ logger.warning(
278
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
279
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
280
+ )
281
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
282
+
283
+ # Use pooled output of CLIPTextModel
284
+ prompt_embeds = prompt_embeds.pooler_output
285
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
286
+
287
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
288
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
289
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
290
+
291
+ return prompt_embeds
292
+
293
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
294
+ def encode_prompt(
295
+ self,
296
+ prompt: Union[str, List[str]],
297
+ prompt_2: Union[str, List[str]],
298
+ device: Optional[torch.device] = None,
299
+ num_images_per_prompt: int = 1,
300
+ prompt_embeds: Optional[torch.FloatTensor] = None,
301
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
302
+ max_sequence_length: int = 512,
303
+ lora_scale: Optional[float] = None,
304
+ ):
305
+ r"""
306
+
307
+ Args:
308
+ prompt (`str` or `List[str]`, *optional*):
309
+ prompt to be encoded
310
+ prompt_2 (`str` or `List[str]`, *optional*):
311
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
312
+ used in all text-encoders
313
+ device: (`torch.device`):
314
+ torch device
315
+ num_images_per_prompt (`int`):
316
+ number of images that should be generated per prompt
317
+ prompt_embeds (`torch.FloatTensor`, *optional*):
318
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
319
+ provided, text embeddings will be generated from `prompt` input argument.
320
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
321
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
322
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
323
+ lora_scale (`float`, *optional*):
324
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
325
+ """
326
+ device = device or self._execution_device
327
+
328
+ # set lora scale so that monkey patched LoRA
329
+ # function of text encoder can correctly access it
330
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
331
+ self._lora_scale = lora_scale
332
+
333
+ # dynamically adjust the LoRA scale
334
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
335
+ scale_lora_layers(self.text_encoder, lora_scale)
336
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
337
+ scale_lora_layers(self.text_encoder_2, lora_scale)
338
+
339
+ prompt = [prompt] if isinstance(prompt, str) else prompt
340
+
341
+ if prompt_embeds is None:
342
+ prompt_2 = prompt_2 or prompt
343
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
344
+
345
+ # We only use the pooled prompt output from the CLIPTextModel
346
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
347
+ prompt=prompt,
348
+ device=device,
349
+ num_images_per_prompt=num_images_per_prompt,
350
+ )
351
+ prompt_embeds = self._get_t5_prompt_embeds(
352
+ prompt=prompt_2,
353
+ num_images_per_prompt=num_images_per_prompt,
354
+ max_sequence_length=max_sequence_length,
355
+ device=device,
356
+ )
357
+
358
+ if self.text_encoder is not None:
359
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
360
+ # Retrieve the original scale by scaling back the LoRA layers
361
+ unscale_lora_layers(self.text_encoder, lora_scale)
362
+
363
+ if self.text_encoder_2 is not None:
364
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
365
+ # Retrieve the original scale by scaling back the LoRA layers
366
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
367
+
368
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
369
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
370
+
371
+ return prompt_embeds, pooled_prompt_embeds, text_ids
372
+
373
+ @torch.no_grad()
374
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
375
+ def __call__(
376
+ self,
377
+ image: PipelineImageInput,
378
+ prompt: Union[str, List[str]] = None,
379
+ prompt_2: Optional[Union[str, List[str]]] = None,
380
+ prompt_embeds: Optional[torch.FloatTensor] = None,
381
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
382
+ prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
383
+ pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
384
+ strength: Optional[Union[float, List[float]]] = 1.0,
385
+ return_dict: bool = True,
386
+ ):
387
+ r"""
388
+ Function invoked when calling the pipeline for generation.
389
+
390
+ Args:
391
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
392
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
393
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
394
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
395
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`
396
+ prompt (`str` or `List[str]`, *optional*):
397
+ The prompt or prompts to guide the image generation. **experimental feature**: to use this feature,
398
+ make sure to explicitly load text encoders to the pipeline. Prompts will be ignored if text encoders
399
+ are not loaded.
400
+ prompt_2 (`str` or `List[str]`, *optional*):
401
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`.
402
+ prompt_embeds (`torch.FloatTensor`, *optional*):
403
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
404
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
405
+ Pre-generated pooled text embeddings.
406
+ return_dict (`bool`, *optional*, defaults to `True`):
407
+ Whether or not to return a [`~pipelines.flux.FluxPriorReduxPipelineOutput`] instead of a plain tuple.
408
+
409
+ Examples:
410
+
411
+ Returns:
412
+ [`~pipelines.flux.FluxPriorReduxPipelineOutput`] or `tuple`:
413
+ [`~pipelines.flux.FluxPriorReduxPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
414
+ returning a tuple, the first element is a list with the generated images.
415
+ """
416
+
417
+ # 1. Check inputs. Raise error if not correct
418
+ self.check_inputs(
419
+ image,
420
+ prompt,
421
+ prompt_2,
422
+ prompt_embeds=prompt_embeds,
423
+ pooled_prompt_embeds=pooled_prompt_embeds,
424
+ prompt_embeds_scale=prompt_embeds_scale,
425
+ pooled_prompt_embeds_scale=pooled_prompt_embeds_scale,
426
+ )
427
+
428
+ # 2. Define call parameters
429
+ if image is not None and isinstance(image, Image.Image):
430
+ batch_size = 1
431
+ elif image is not None and isinstance(image, list):
432
+ batch_size = len(image)
433
+ else:
434
+ batch_size = image.shape[0]
435
+ if prompt is not None and isinstance(prompt, str):
436
+ prompt = batch_size * [prompt]
437
+ if isinstance(prompt_embeds_scale, float):
438
+ prompt_embeds_scale = batch_size * [prompt_embeds_scale]
439
+ if isinstance(pooled_prompt_embeds_scale, float):
440
+ pooled_prompt_embeds_scale = batch_size * [pooled_prompt_embeds_scale]
441
+ if isinstance(strength, float):
442
+ strength = batch_size * [strength]
443
+
444
+ device = self._execution_device
445
+
446
+ # 3. Prepare image embeddings
447
+ image_latents = self.encode_image(image, device, 1)
448
+
449
+ image_embeds = self.image_embedder(image_latents).image_embeds
450
+ image_embeds = image_embeds.to(device=device)
451
+
452
+ # 4. Prepare (dummy) text embeddings
453
+ if hasattr(self, "text_encoder") and self.text_encoder is not None:
454
+ (
455
+ prompt_embeds,
456
+ pooled_prompt_embeds,
457
+ _,
458
+ ) = self.encode_prompt(
459
+ prompt=prompt,
460
+ prompt_2=prompt_2,
461
+ prompt_embeds=prompt_embeds,
462
+ pooled_prompt_embeds=pooled_prompt_embeds,
463
+ device=device,
464
+ num_images_per_prompt=1,
465
+ max_sequence_length=512,
466
+ lora_scale=None,
467
+ )
468
+ else:
469
+ if prompt is not None:
470
+ logger.warning(
471
+ "prompt input is ignored when text encoders are not loaded to the pipeline. "
472
+ "Make sure to explicitly load the text encoders to enable prompt input. "
473
+ )
474
+ # max_sequence_length is 512, t5 encoder hidden size is 4096
475
+ prompt_embeds = torch.zeros((batch_size, 512, 4096), device=device, dtype=image_embeds.dtype)
476
+ # pooled_prompt_embeds is 768, clip text encoder hidden size
477
+ pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype)
478
+
479
+ # apply strength to image_embeds
480
+ image_embeds *= torch.tensor(strength, device=device, dtype=image_embeds.dtype)[:, None, None]
481
+
482
+ # scale & concatenate image and text embeddings
483
+ prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1)
484
+
485
+ prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None]
486
+ pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[
487
+ :, None
488
+ ]
489
+
490
+ # weighted sum
491
+ prompt_embeds = torch.sum(prompt_embeds, dim=0, keepdim=True)
492
+ pooled_prompt_embeds = torch.sum(pooled_prompt_embeds, dim=0, keepdim=True)
493
+
494
+ # Offload all models
495
+ self.maybe_free_model_hooks()
496
+
497
+ if not return_dict:
498
+ return (prompt_embeds, pooled_prompt_embeds)
499
+
500
+ return FluxPriorReduxPipelineOutput(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds)
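A hedged sketch of how this Redux prior is meant to be combined with the img2img pipeline above: the prior turns a reference image into `prompt_embeds`/`pooled_prompt_embeds`, down-weighted by the `strength` argument added here, and the output is unpacked into the generation call (model ids match the example docstring; file names and parameter values are placeholders):

```py
import torch
from diffusers.utils import load_image
from pipeline.custom_pipelines import FluxPriorReduxPipeline, FluxImg2ImgPipeline

redux = FluxPriorReduxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

reference = load_image("reference.png")   # image whose appearance should be reused
init_image = load_image("init.png")       # starting point for img2img

# strength < 1.0 scales down the reference image embeddings before they are
# concatenated with the (dummy) text embeddings.
prior_out = redux(reference, strength=0.8)

result = pipe(
    image=init_image,
    strength=0.6,
    num_inference_steps=28,
    guidance_scale=3.5,
    **prior_out,
).images[0]
result.save("redux_img2img.png")
```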
pipeline/example_text_to_3d.py ADDED
@@ -0,0 +1,7 @@
1
+ from pipeline.kiss3d_wrapper import init_wrapper_from_config, run_text_to_3d, run_image_to_3d
2
+
3
+ if __name__ == "__main__":
4
+ k3d_wrapper = init_wrapper_from_config('/hpc2hdd/home/jlin695/code/github/Kiss3DGen/pipeline/pipeline_config/default.yaml')
5
+
6
+ run_text_to_3d(k3d_wrapper, prompt='A doll of a girl in Harry Potter')
7
+
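For symmetry with the text-to-3D example above, a companion sketch for the image-to-3D entry point (hedged: this assumes `run_image_to_3d` still takes the wrapper plus an input image path, as in the earlier version of `kiss3d_wrapper.py`; the image path is a placeholder):

```py
from pipeline.kiss3d_wrapper import init_wrapper_from_config, run_image_to_3d

if __name__ == "__main__":
    k3d_wrapper = init_wrapper_from_config('/hpc2hdd/home/jlin695/code/github/Kiss3DGen/pipeline/pipeline_config/default.yaml')

    # Placeholder input image; replace with a real file on disk.
    run_image_to_3d(k3d_wrapper, 'path/to/input_image.png')
```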
pipeline/kiss3d_wrapper.py CHANGED
@@ -1,7 +1,9 @@
1
  # The kiss3d pipeline wrapper for inference
2
 
3
  import os
 
4
  import numpy as np
 
5
  import torch
6
  import yaml
7
  import uuid
@@ -10,49 +12,93 @@ from einops import rearrange
10
  from PIL import Image
11
 
12
  from pipeline.utils import logger, TMP_DIR, OUT_DIR
13
- from pipeline.utils import lrm_reconstruct, isomer_reconstruct
14
 
15
  import torch
16
  import torchvision
 
17
 
18
  # for reconstruction model
19
  from omegaconf import OmegaConf
20
  from models.lrm.utils.train_util import instantiate_from_config
21
  from models.lrm.utils.render_utils import rotate_x, rotate_y
 
22
  from utils.tool import get_background
23
-
24
  # for florence2
25
- from transformers import AutoProcessor, AutoModelForCausalLM
26
-
27
- from diffusers import FluxPipeline, FluxControlNetImg2ImgPipeline, FluxImg2ImgPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
28
- from diffusers.models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
29
-
30
 
 
31
  def init_wrapper_from_config(config_path):
32
  with open(config_path, 'r') as config_file:
33
  config_ = yaml.load(config_file, yaml.FullLoader)
34
 
35
  # init flux_pipeline
36
  logger.info('==> Loading Flux model ...')
37
  flux_device = config_['flux'].get('device', 'cpu')
38
  flux_base_model_pth = config_['flux'].get('base_model', None)
 
39
  flux_controlnet_pth = config_['flux'].get('controlnet', None)
40
- flux_lora_pth = config_['flux'].get('lora', None)
41
-
42
- # load flux model and controlnet
43
- if flux_controlnet_pth is not None:
44
- flux_controlnet = FluxControlNetModel.from_pretrained(flux_controlnet_pth)
45
- flux_pipe = FluxControlNetImg2ImgPipeline.from_pretrained(flux_base_model_pth, controlnet=[flux_controlnet], \
46
- torch_dtype=torch.bfloat16)
47
  else:
48
- flux_pipe = FluxImg2ImgPipeline(flux_base_model_pth, torch_dtype=torch.bfloat16)
49
 
 
 
 
 
 
 
 
50
  # load lora weights
51
  flux_pipe.load_lora_weights(flux_lora_pth)
52
- flux_pipe.to(device=flux_device, dtype=torch.bfloat16)
53
 
54
- # TODO: load redux model
55
- # FluxPriorReduxPipeline.from_pretrained()
56
 
57
  # TODO: load pulid model
58
 
@@ -68,13 +114,15 @@ def init_wrapper_from_config(config_path):
68
  multiview_pipeline.scheduler.config, timestep_spacing='trailing'
69
  )
70
 
71
- unet_ckpt_path = config_['multiview'].get('unet', None)
 
72
  if unet_ckpt_path is not None:
73
  state_dict = torch.load(unet_ckpt_path, map_location='cpu')['state_dict']
74
  state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')}
75
  multiview_pipeline.unet.load_state_dict(state_dict, strict=True)
76
 
77
  multiview_pipeline.to(multiview_device)
 
78
 
79
  # load caption model
80
  logger.info('==> Loading caption model ...')
@@ -82,6 +130,7 @@ def init_wrapper_from_config(config_path):
82
  caption_model = AutoModelForCausalLM.from_pretrained(config_['caption']['base_model'], \
83
  torch_dtype=torch.bfloat16, trust_remote_code=True).to(caption_device)
84
  caption_processor = AutoProcessor.from_pretrained(config_['caption']['base_model'], trust_remote_code=True)
 
85
 
86
  # load reconstruction model
87
  logger.info('==> Loading reconstruction model ...')
@@ -89,40 +138,79 @@ def init_wrapper_from_config(config_path):
89
  recon_model_config = OmegaConf.load(config_['reconstruction']['model_config'])
90
  recon_model = instantiate_from_config(recon_model_config.model_config)
91
  # load recon model checkpoint
92
- state_dict = torch.load(config_['reconstruction']['base_model'], map_location='cpu')['state_dict']
 
93
  state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
94
  recon_model.load_state_dict(state_dict, strict=True)
95
  recon_model.to(recon_device)
96
  recon_model.init_flexicubes_geometry(recon_device, fovy=50.0)
97
  recon_model.eval()
98
 
99
  return kiss3d_wrapper(
100
  config = config_,
101
  flux_pipeline = flux_pipe,
 
102
  multiview_pipeline = multiview_pipeline,
103
  caption_processor = caption_processor,
104
  caption_model = caption_model,
105
  reconstruction_model_config = recon_model_config,
106
  reconstruction_model = recon_model,
 
 
107
  )
108
109
  class kiss3d_wrapper(object):
110
  def __init__(self,
111
  config: Dict,
112
  flux_pipeline: Union[FluxPipeline, FluxControlNetImg2ImgPipeline],
 
113
  multiview_pipeline: DiffusionPipeline,
114
  caption_processor: AutoProcessor,
115
  caption_model: AutoModelForCausalLM,
116
  reconstruction_model_config: Any,
117
  reconstruction_model: Any,
 
 
118
  ):
119
  self.config = config
120
  self.flux_pipeline = flux_pipeline
 
121
  self.multiview_pipeline = multiview_pipeline
122
  self.caption_model = caption_model
123
  self.caption_processor = caption_processor
124
  self.recon_model_config = reconstruction_model_config
125
- self.recon_model = reconstruction_model
 
 
 
 
 
 
 
126
 
127
  self.renew_uuid()
128
 
@@ -144,12 +232,10 @@ class kiss3d_wrapper(object):
144
  caption_device = self.config['caption'].get('device', 'cpu')
145
 
146
  if isinstance(image, str): # If image is a file path
147
- image = Image.open(image).convert("RGB")
148
- elif isinstance(image, Image):
149
- image = image.convert("RGB")
150
- else:
151
  raise NotImplementedError('unexpected image type')
152
-
153
  prompt = "<MORE_DETAILED_CAPTION>"
154
  inputs = self.caption_processor(text=prompt, images=image, return_tensors="pt").to(caption_device, torch_dtype)
155
 
@@ -161,17 +247,45 @@ class kiss3d_wrapper(object):
161
  parsed_answer = self.caption_processor.post_process_generation(
162
  generated_text, task=prompt, image_size=(image.width, image.height)
163
  )
164
- caption_text = parsed_answer["<MORE_DETAILED_CAPTION>"].replace("The image is ", "")
 
 
 
 
 
165
  return caption_text
166
 
167
- def generate_multiview(self, image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  with self.context():
169
  mv_image = self.multiview_pipeline(image,
170
- num_inference_steps=self.config['multiview']['num_inference_steps'],
171
- width=512*2, height=512*2).images[0]
 
 
172
  return mv_image
173
 
174
- def reconstruct_from_multiview(self, mv_image):
175
  """
176
  mv_image: PIL.Image
177
  """
@@ -184,23 +298,31 @@ class kiss3d_wrapper(object):
184
  with self.context():
185
  vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = \
186
  lrm_reconstruct(self.recon_model, self.recon_model_config.infer_config,
187
- rgb_multi_view, name=self.uuid)
188
 
189
- return vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo
190
 
191
- def generate_reference_3D_bundle_image_zero123(self, image, save_intermediate_results=True):
192
  """
193
  input: image, PIL.Image
194
- return: ref_3D_bundle_image, Tensor of shape (1, 3, 1024, 2048)
195
  """
196
  mv_image = self.generate_multiview(image)
197
 
198
  if save_intermediate_results:
199
  mv_image.save(os.path.join(TMP_DIR, f'{self.uuid}_mv_image.png'))
200
 
201
- vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = self.reconstruct_from_multiview(mv_image)
 
 
 
202
 
203
- ref_3D_bundle_image = torchvision.utils.make_grid(torch.cat([lrm_multi_view_rgb.cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0) # range [0, 1]
 
 
 
 
 
204
 
205
  if save_intermediate_results:
206
  save_path = os.path.join(TMP_DIR, f'{self.uuid}_ref_3d_bundle_image.png')
@@ -222,6 +344,9 @@ class kiss3d_wrapper(object):
222
  control_guidance_end=None,
223
  controlnet_conditioning_scale=None,
224
  lora_scale=1.0,
 
 
 
225
  save_intermediate_results=True,
226
  **kwargs):
227
  control_mode_dict = {
@@ -235,15 +360,20 @@ class kiss3d_wrapper(object):
235
  } # for https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union only
236
 
237
  flux_device = self.config['flux'].get('device', 'cpu')
238
- seed = self.config['flux'].get('seed', 0)
 
239
 
240
  generator = torch.Generator(device=flux_device).manual_seed(seed)
241
 
 
 
 
242
  hparam_dict = {
243
- 'prompt': ' '.join(['A grid of 2x4 multi-view image, elevation 5. White background.', prompt]),
244
- 'image': image or torch.zeros((1, 3, 1024, 2048), dtype=torch.float32, device=flux_device),
 
245
  'strength': strength,
246
- 'num_inference_steps': 30,
247
  'guidance_scale': 3.5,
248
  'num_images_per_prompt': 1,
249
  'width': 2048,
@@ -253,14 +383,29 @@ class kiss3d_wrapper(object):
253
  'joint_attention_kwargs': {"scale": lora_scale}
254
  }
255
  hparam_dict.update(kwargs)
256
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  # append controlnet hparams
258
  if len(control_image) > 0:
259
  assert isinstance(self.flux_pipeline, FluxControlNetImg2ImgPipeline)
260
  assert len(control_mode) == len(control_image) # the count of image should be the same as control mode
261
 
262
  flux_ctrl_net = self.flux_pipeline.controlnet.nets[0]
263
- self.flux_pipeline.controlnet = FluxMultiControlNetModel([flux_ctrl_net for i in range(len(control_image))])
264
 
265
  ctrl_hparams = {
266
  'control_mode': [control_mode_dict[mode_] for mode_ in control_mode],
@@ -285,13 +430,45 @@ class kiss3d_wrapper(object):
285
 
286
  return gen_3d_bundle_image_
287
 
289
  def generate_3d_bundle_image_text(self,
290
  prompt,
291
  image=None,
292
  strength=1.0,
293
  lora_scale=1.0,
294
- num_inference_steps=30,
 
 
295
  save_intermediate_results=True,
296
  **kwargs):
297
 
@@ -299,27 +476,25 @@ class kiss3d_wrapper(object):
299
  return: gen_3d_bundle_image, torch.Tensor of shape (3, 1024, 2048), range [0., 1.]
300
  """
301
 
302
- if isinstance(self.flux_pipeline, FluxControlNetImg2ImgPipeline):
303
- flux_pipeline = FluxImg2ImgPipeline(
304
- scheduler = self.flux_pipeline.scheduler,
305
- vae = self.flux_pipeline.vae,
306
- text_encoder = self.flux_pipeline.text_encoder,
307
- tokenizer = self.flux_pipeline.tokenizer,
308
- text_encoder_2 = self.flux_pipeline.text_encoder_2,
309
- tokenizer_2 = self.flux_pipeline.tokenizer_2,
310
- transformer = self.flux_pipeline.transformer
311
- )
312
- else:
313
  flux_pipeline = self.flux_pipeline
 
 
314
 
315
  flux_device = self.config['flux'].get('device', 'cpu')
316
- seed = self.config['flux'].get('seed', 0)
 
 
 
 
317
 
318
  generator = torch.Generator(device=flux_device).manual_seed(seed)
319
 
 
320
  hparam_dict = {
321
- 'prompt': ' '.join(['A grid of 2x4 multi-view image, elevation 5. White background.', prompt]),
322
- 'image': image or torch.zeros((1, 3, 1024, 2048), dtype=torch.float32, device=flux_device),
 
323
  'strength': strength,
324
  'num_inference_steps': num_inference_steps,
325
  'guidance_scale': 3.5,
@@ -332,6 +507,22 @@ class kiss3d_wrapper(object):
332
  }
333
  hparam_dict.update(kwargs)
334
335
  with self.context():
336
  gen_3d_bundle_image = flux_pipeline(**hparam_dict).images
337
 
@@ -345,7 +536,13 @@ class kiss3d_wrapper(object):
345
 
346
  return gen_3d_bundle_image_
347
 
348
- def reconstruct_3d_bundle_image(self, image, save_intermediate_results=True):
349
  """
350
  image: torch.Tensor, range [0., 1.], (3, 1024, 2048)
351
  """
@@ -355,6 +552,8 @@ class kiss3d_wrapper(object):
355
  images = rearrange(image, 'c (n h) (m w) -> (n m) c h w', n=2, m=4) # (3, 1024, 2048) -> (8, 3, 512, 512)
356
  rgb_multi_view, normal_multi_view = images.chunk(2, dim=0)
357
  multi_view_mask = get_background(normal_multi_view).to(recon_device)
 
 
358
  rgb_multi_view = rgb_multi_view.to(recon_device) * multi_view_mask + (1 - multi_view_mask)
359
 
360
  with self.context():
@@ -362,11 +561,12 @@ class kiss3d_wrapper(object):
362
  lrm_reconstruct(self.recon_model, self.recon_model_config.infer_config,
363
  rgb_multi_view.unsqueeze(0).to(recon_device), name=self.uuid,
364
  input_camera_type='kiss3d', render_3d_bundle_image=save_intermediate_results,
365
- render_azimuths=[0, 90, 180, 270])
 
366
 
367
  if save_intermediate_results:
368
  recon_3D_bundle_image = torchvision.utils.make_grid(torch.cat([lrm_multi_view_rgb.cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0) # range [0, 1]
369
- torchvision.utils.save_image(recon_3D_bundle_image, os.path.join(TMP_DIR, f'{k3d_wrapper.uuid})_lrm_recon_3d_bundle_image.png'))
370
 
371
  recon_mesh_path = os.path.join(TMP_DIR, f"{self.uuid}_isomer_recon_mesh.obj")
372
 
@@ -375,7 +575,11 @@ class kiss3d_wrapper(object):
375
  multi_view_mask=multi_view_mask,
376
  vertices=vertices,
377
  faces=faces,
378
- save_path=recon_mesh_path)
 
 
 
 
379
 
380
 
381
  def run_text_to_3d(k3d_wrapper,
@@ -391,39 +595,176 @@ def run_text_to_3d(k3d_wrapper,
391
  if init_image_path is not None:
392
  init_image = Image.open(init_image_path)
393
 
 
 
 
 
 
394
  gen_3d_bundle_image, gen_save_path = k3d_wrapper.generate_3d_bundle_image_text(prompt,
395
- image=init_image,
396
- strength=1.0,
397
- save_intermediate_results=True)
398
 
399
  # recon from 3D Bundle image
400
  recon_mesh_path = k3d_wrapper.reconstruct_3d_bundle_image(gen_3d_bundle_image, save_intermediate_results=False)
401
 
402
  return gen_save_path, recon_mesh_path
403
 
404
- def run_image_to_3d(k3d_wrapper, init_image_path):
 
405
  # ======================================= Example of image to 3D generation ======================================
406
 
407
  # Renew The uuid
408
  k3d_wrapper.renew_uuid()
409
 
410
  # FOR IMAGE TO 3D: generate reference 3D bundle image from a single input image
411
- input_image = Image.open(init_image_path)
412
- reference_3d_bundle_image, reference_save_path = k3d_wrapper.generate_reference_3D_bundle_image_zero123(input_image)
 
 
413
  caption = k3d_wrapper.get_image_caption(input_image)
414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
 
416
- import pdb
417
- pdb.set_trace()
418
 
419
 
420
  if __name__ == "__main__":
421
- k3d_wrapper = init_wrapper_from_config('/hpc2hdd/home/jlin695/code/Kiss3DGen/pipeline/pipeline_config/default.yaml')
 
 
 
422
 
423
- # Example of loading existing 3D bundle Image
424
- # demo_image = Image.open('/hpc2hdd/home/jlin695/code/github/Kiss3DGen/outputs/tmp/ea25bc9b-d775-46bb-9827-660a9a6540c8_gen_3d_bundle_image.png')
425
- # gen_3d_bundle_image = torchvision.transforms.functional.to_tensor(demo_image)
426
 
427
- run_image_to_3d(k3d_wrapper, '/hpc2hdd/home/jlin695/code/Kiss3DGen/examples/蓝色小怪物.webp')
  # run_text_to_3d(k3d_wrapper, prompt='A doll of a girl in Harry Potter')
429
 
 
 
 
 
 
1
  # The kiss3d pipeline wrapper for inference
2
 
3
  import os
4
+ import spaces
5
  import numpy as np
6
+ import random
7
  import torch
8
  import yaml
9
  import uuid
 
12
  from PIL import Image
13
 
14
  from pipeline.utils import logger, TMP_DIR, OUT_DIR
15
+ from pipeline.utils import lrm_reconstruct, isomer_reconstruct, preprocess_input_image
16
 
17
  import torch
18
  import torchvision
19
+ from torch.nn import functional as F
20
 
21
  # for reconstruction model
22
  from omegaconf import OmegaConf
23
  from models.lrm.utils.train_util import instantiate_from_config
24
  from models.lrm.utils.render_utils import rotate_x, rotate_y
25
+ #
26
  from utils.tool import get_background
 
27
  # for florence2
28
+ from transformers import AutoProcessor, AutoModelForCausalLM, AutoTokenizer
29
+ from models.llm.llm import load_llm_model, get_llm_response
30
+
31
+ from pipeline.custom_pipelines import FluxPriorReduxPipeline, FluxControlNetImg2ImgPipeline, FluxImg2ImgPipeline
32
+ from diffusers import FluxPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler, FluxTransformer2DModel
33
+ from diffusers.models.controlnets.controlnet_flux import FluxMultiControlNetModel, FluxControlNetModel
34
+ from diffusers.schedulers import FlowMatchHeunDiscreteScheduler
35
+ from huggingface_hub import hf_hub_download
36
+ access_token = os.getenv("HUGGINGFACE_TOKEN")
37
+
38
+
39
+ def convert_flux_pipeline(exist_flux_pipe, target_pipe, **kwargs):
40
+ new_pipe = target_pipe(
41
+ scheduler = exist_flux_pipe.scheduler,
42
+ vae = exist_flux_pipe.vae,
43
+ text_encoder = exist_flux_pipe.text_encoder,
44
+ tokenizer = exist_flux_pipe.tokenizer,
45
+ text_encoder_2 = exist_flux_pipe.text_encoder_2,
46
+ tokenizer_2 = exist_flux_pipe.tokenizer_2,
47
+ transformer = exist_flux_pipe.transformer,
48
+ **kwargs
49
+ )
50
+ return new_pipe
51
 
52
+ @spaces.GPU
53
  def init_wrapper_from_config(config_path):
54
  with open(config_path, 'r') as config_file:
55
  config_ = yaml.load(config_file, yaml.FullLoader)
56
+
57
+ dtype_ = {
58
+ 'fp8': torch.float8_e4m3fn,
59
+ 'bf16': torch.bfloat16,
60
+ 'fp16': torch.float16,
61
+ 'fp32': torch.float32
62
+ }
63
 
64
  # init flux_pipeline
65
  logger.info('==> Loading Flux model ...')
66
  flux_device = config_['flux'].get('device', 'cpu')
67
  flux_base_model_pth = config_['flux'].get('base_model', None)
68
+ flux_dtype = config_['flux'].get('dtype', 'bf16')
69
  flux_controlnet_pth = config_['flux'].get('controlnet', None)
70
+ # flux_lora_pth = config_['flux'].get('lora', None)
71
+ flux_lora_pth = hf_hub_download(repo_id="LTT/xxx-ckpt", filename="rgb_normal_large.safetensors", repo_type="model", token=access_token)
72
+ flux_redux_pth = config_['flux'].get('redux', None)
73
+
74
+ if flux_base_model_pth.endswith('safetensors'):
75
+ flux_pipe = FluxImg2ImgPipeline.from_single_file(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
 
76
  else:
77
+ flux_pipe = FluxImg2ImgPipeline.from_pretrained(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
78
 
79
+ # load flux model and controlnet
80
+ if flux_controlnet_pth is not None:
81
+ flux_controlnet = FluxControlNetModel.from_pretrained(flux_controlnet_pth, torch_dtype=torch.bfloat16)
82
+ flux_pipe = convert_flux_pipeline(flux_pipe, FluxControlNetImg2ImgPipeline, controlnet=[flux_controlnet])
83
+
84
+ flux_pipe.scheduler = FlowMatchHeunDiscreteScheduler.from_config(flux_pipe.scheduler.config)
85
+
86
  # load lora weights
87
  flux_pipe.load_lora_weights(flux_lora_pth)
88
+ flux_pipe.to(device=flux_device)
89
 
90
+ # load redux model
91
+ flux_redux_pipe = None
92
+ if flux_redux_pth is not None:
93
+ flux_redux_pipe = FluxPriorReduxPipeline.from_pretrained(flux_redux_pth, torch_dtype=torch.bfloat16)
94
+ flux_redux_pipe.text_encoder = flux_pipe.text_encoder
95
+ flux_redux_pipe.text_encoder_2 = flux_pipe.text_encoder_2
96
+ flux_redux_pipe.tokenizer = flux_pipe.tokenizer
97
+ flux_redux_pipe.tokenizer_2 = flux_pipe.tokenizer_2
98
+
99
+ flux_redux_pipe.to(device=flux_device)
100
+
101
+ logger.warning(f"GPU memory allocated after load flux model on {flux_device}: {torch.cuda.memory_allocated(device=flux_device) / 1024**3} GB")
102
 
103
  # TODO: load pulid model
104
 
 
114
  multiview_pipeline.scheduler.config, timestep_spacing='trailing'
115
  )
116
 
117
+ # unet_ckpt_path = config_['multiview'].get('unet', None)
118
+ unet_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="flexgen_19w.ckpt", repo_type="model")
119
  if unet_ckpt_path is not None:
120
  state_dict = torch.load(unet_ckpt_path, map_location='cpu')['state_dict']
121
  state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')}
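  # k[10:] strips the 'unet.unet.' prefix (10 characters) so the keys match the bare UNet module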
122
  multiview_pipeline.unet.load_state_dict(state_dict, strict=True)
123
 
124
  multiview_pipeline.to(multiview_device)
125
+ logger.warning(f"GPU memory allocated after load multiview model on {multiview_device}: {torch.cuda.memory_allocated(device=multiview_device) / 1024**3} GB")
126
 
127
  # load caption model
128
  logger.info('==> Loading caption model ...')
 
130
  caption_model = AutoModelForCausalLM.from_pretrained(config_['caption']['base_model'], \
131
  torch_dtype=torch.bfloat16, trust_remote_code=True).to(caption_device)
132
  caption_processor = AutoProcessor.from_pretrained(config_['caption']['base_model'], trust_remote_code=True)
133
+ logger.warning(f"GPU memory allocated after load caption model on {caption_device}: {torch.cuda.memory_allocated(device=caption_device) / 1024**3} GB")
134
 
135
  # load reconstruction model
136
  logger.info('==> Loading reconstruction model ...')
 
138
  recon_model_config = OmegaConf.load(config_['reconstruction']['model_config'])
139
  recon_model = instantiate_from_config(recon_model_config.model_config)
140
  # load recon model checkpoint
141
+ model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
142
+ state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
143
  state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
144
  recon_model.load_state_dict(state_dict, strict=True)
145
  recon_model.to(recon_device)
146
  recon_model.init_flexicubes_geometry(recon_device, fovy=50.0)
147
  recon_model.eval()
148
+ logger.warning(f"GPU memory allocated after load reconstruction model on {recon_device}: {torch.cuda.memory_allocated(device=recon_device) / 1024**3} GB")
149
+
150
+ # load llm
151
+ llm_configs = config_.get('llm', None)
152
+ if llm_configs is not None:
153
+ logger.info('==> Loading LLM ...')
154
+ llm_device = llm_configs.get('device', 'cpu')
155
+ llm, llm_tokenizer = load_llm_model(llm_configs['base_model'])
156
+ llm.to(llm_device)
157
+ logger.warning(f"GPU memory allocated after load llm model on {llm_device}: {torch.cuda.memory_allocated(device=llm_device) / 1024**3} GB")
158
+ else:
159
+ llm, llm_tokenizer = None, None
160
 
161
  return kiss3d_wrapper(
162
  config = config_,
163
  flux_pipeline = flux_pipe,
164
+ flux_redux_pipeline=flux_redux_pipe,
165
  multiview_pipeline = multiview_pipeline,
166
  caption_processor = caption_processor,
167
  caption_model = caption_model,
168
  reconstruction_model_config = recon_model_config,
169
  reconstruction_model = recon_model,
170
+ llm_model = llm,
171
+ llm_tokenizer = llm_tokenizer
172
  )
173
 
174
+ def seed_everything(seed):
175
+
176
+ random.seed(seed)
177
+ np.random.seed(seed)
178
+ torch.manual_seed(seed)
179
+ torch.cuda.manual_seed(seed)
180
+ torch.cuda.manual_seed_all(seed)
181
+ torch.backends.cudnn.deterministic = True
182
+ torch.backends.cudnn.benchmark = False
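+ # deterministic cuDNN kernels trade some speed for run-to-run reproducibility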
183
+
184
+ print(f"Random seed set to {seed}")
185
+
186
  class kiss3d_wrapper(object):
187
  def __init__(self,
188
  config: Dict,
189
  flux_pipeline: Union[FluxPipeline, FluxControlNetImg2ImgPipeline],
190
+ flux_redux_pipeline: FluxPriorReduxPipeline,
191
  multiview_pipeline: DiffusionPipeline,
192
  caption_processor: AutoProcessor,
193
  caption_model: AutoModelForCausalLM,
194
  reconstruction_model_config: Any,
195
  reconstruction_model: Any,
196
+ llm_model: AutoModelForCausalLM = None,
197
+ llm_tokenizer: AutoTokenizer = None
198
  ):
199
  self.config = config
200
  self.flux_pipeline = flux_pipeline
201
+ self.flux_redux_pipeline = flux_redux_pipeline
202
  self.multiview_pipeline = multiview_pipeline
203
  self.caption_model = caption_model
204
  self.caption_processor = caption_processor
205
  self.recon_model_config = reconstruction_model_config
206
+ self.recon_model = reconstruction_model
207
+ self.llm_model = llm_model
208
+ self.llm_tokenizer = llm_tokenizer
209
+
210
+ self.to_512_tensor = torchvision.transforms.Compose([
211
+ torchvision.transforms.ToTensor(),
212
+ torchvision.transforms.Resize((512, 512), interpolation=2),
213
+ ])
214
 
215
  self.renew_uuid()
216
 
 
232
  caption_device = self.config['caption'].get('device', 'cpu')
233
 
234
  if isinstance(image, str): # If image is a file path
235
+ image = preprocess_input_image(Image.open(image))
236
+ elif not isinstance(image, Image.Image):
237
  raise NotImplementedError('unexpected image type')
238
+
239
  prompt = "<MORE_DETAILED_CAPTION>"
240
  inputs = self.caption_processor(text=prompt, images=image, return_tensors="pt").to(caption_device, torch_dtype)
241
 
 
247
  parsed_answer = self.caption_processor.post_process_generation(
248
  generated_text, task=prompt, image_size=(image.width, image.height)
249
  )
250
+ caption_text = parsed_answer["<MORE_DETAILED_CAPTION>"] # .replace("The image is ", "")
251
+
252
+ logger.info(f"Auto caption result: \"{caption_text}\"")
253
+
254
+ caption_text = self.get_detailed_prompt(caption_text)
255
+
256
  return caption_text
257
 
258
+ def get_detailed_prompt(self, prompt, seed=None):
259
+ if self.llm_model is not None:
260
+ detailed_prompt = get_llm_response(self.llm_model, self.llm_tokenizer, prompt, seed=seed)
261
+
262
+ logger.info(f"LLM refined prompt result: \"{detailed_prompt}\"")
263
+ return detailed_prompt
264
+ return prompt
265
+
266
+ def del_llm_model(self):
267
+ logger.warning('This function is now deprecated and has no effect')
268
+
269
+ # raise NotImplementedError()
270
+ # del llm.model
271
+ # del llm.tokenizer
272
+ # llm.model = None
273
+ # llm.tokenizer = None
274
+
275
+ def generate_multiview(self, image, seed=None, num_inference_steps=None):
276
+ seed = seed or self.config['multiview'].get('seed', 0)
277
+ mv_device = self.config['multiview'].get('device', 'cpu')
278
+
279
+ generator = torch.Generator(device=mv_device).manual_seed(seed)
280
  with self.context():
281
  mv_image = self.multiview_pipeline(image,
282
+ num_inference_steps=num_inference_steps or self.config['multiview']['num_inference_steps'],
283
+ width=512*2,
284
+ height=512*2,
285
+ generator=generator).images[0]
286
  return mv_image
287
 
288
+ def reconstruct_from_multiview(self, mv_image, lrm_render_radius=4.15):
289
  """
290
  mv_image: PIL.Image
291
  """
 
298
  with self.context():
299
  vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = \
300
  lrm_reconstruct(self.recon_model, self.recon_model_config.infer_config,
301
+ rgb_multi_view, name=self.uuid, render_radius=lrm_render_radius)
302
 
303
+ return rgb_multi_view, vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo
304
 
305
+ def generate_reference_3D_bundle_image_zero123(self, image, use_mv_rgb=False, save_intermediate_results=True):
306
  """
307
  input: image, PIL.Image
308
+ return: ref_3D_bundle_image, Tensor of shape (3, 1024, 2048)
309
  """
310
  mv_image = self.generate_multiview(image)
311
 
312
  if save_intermediate_results:
313
  mv_image.save(os.path.join(TMP_DIR, f'{self.uuid}_mv_image.png'))
314
 
315
+ rgb_multi_view, vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = self.reconstruct_from_multiview(mv_image)
316
+
317
+ if use_mv_rgb:
318
+ # ref_3D_bundle_image = torchvision.utils.make_grid(torch.cat([rgb_multi_view[0, [3, 0, 1, 2], ...].cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0) # range [0, 1]
319
 
320
+ rgb_ = torch.cat([rgb_multi_view[0, [3, 0, 1, 2], ...].cpu(), lrm_multi_view_rgb.cpu()], dim=0)
321
+ ref_3D_bundle_image = torchvision.utils.make_grid(torch.cat([rgb_[[0, 5, 2, 7], ...], (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0) # range [0, 1]
322
+ else:
323
+ ref_3D_bundle_image = torchvision.utils.make_grid(torch.cat([lrm_multi_view_rgb.cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0) # range [0, 1]
324
+
325
+ ref_3D_bundle_image = ref_3D_bundle_image.clip(0., 1.)
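+ # make_grid with nrow=4 tiles the 8 views into the 2x4 bundle layout: top row RGB, bottom row normal maps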
326
 
327
  if save_intermediate_results:
328
  save_path = os.path.join(TMP_DIR, f'{self.uuid}_ref_3d_bundle_image.png')
 
344
  control_guidance_end=None,
345
  controlnet_conditioning_scale=None,
346
  lora_scale=1.0,
347
+ num_inference_steps=None,
348
+ seed=None,
349
+ redux_hparam=None,
350
  save_intermediate_results=True,
351
  **kwargs):
352
  control_mode_dict = {
 
360
  } # for https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union only
361
 
362
  flux_device = self.config['flux'].get('device', 'cpu')
363
+ seed = seed or self.config['flux'].get('seed', 0)
364
+ num_inference_steps = num_inference_steps or self.config['flux'].get('num_inference_steps', 20)
365
 
366
  generator = torch.Generator(device=flux_device).manual_seed(seed)
367
 
368
+ if image is None:
369
+ image = torch.zeros((1, 3, 1024, 2048), dtype=torch.float32, device=flux_device)
370
+
371
  hparam_dict = {
372
+ 'prompt': 'A grid of 2x4 multi-view image, elevation 5. White background.',
373
+ 'prompt_2': ' '.join(['A grid of 2x4 multi-view image, elevation 5. White background.', prompt]),
374
+ 'image': image,
375
  'strength': strength,
376
+ 'num_inference_steps': num_inference_steps,
377
  'guidance_scale': 3.5,
378
  'num_images_per_prompt': 1,
379
  'width': 2048,
 
383
  'joint_attention_kwargs': {"scale": lora_scale}
384
  }
385
  hparam_dict.update(kwargs)
386
+
387
+ # do redux
388
+ if redux_hparam is not None:
389
+ assert self.flux_redux_pipeline is not None
390
+ assert 'image' in redux_hparam.keys()
391
+ redux_hparam_ = {
392
+ 'prompt': hparam_dict.pop('prompt'),
393
+ 'prompt_2': hparam_dict.pop('prompt_2'),
394
+ }
395
+ redux_hparam_.update(redux_hparam)
396
+
397
+ with self.context():
398
+ redux_output = self.flux_redux_pipeline(**redux_hparam_)
399
+
400
+ hparam_dict.update(redux_output)
401
+
402
  # append controlnet hparams
403
  if len(control_image) > 0:
404
  assert isinstance(self.flux_pipeline, FluxControlNetImg2ImgPipeline)
405
  assert len(control_mode) == len(control_image) # the number of control images must match the number of control modes
406
 
407
  flux_ctrl_net = self.flux_pipeline.controlnet.nets[0]
408
+ self.flux_pipeline.controlnet = FluxMultiControlNetModel([flux_ctrl_net for _ in control_mode])
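+ # the single loaded ControlNet is reused for every requested control mode via FluxMultiControlNetModel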
409
 
410
  ctrl_hparams = {
411
  'control_mode': [control_mode_dict[mode_] for mode_ in control_mode],
 
430
 
431
  return gen_3d_bundle_image_
432
 
433
+ def preprocess_controlnet_cond_image(self, image, control_mode, save_intermediate_results=True, **kwargs):
434
+ """
435
+ image: Tensor of shape (c, h, w), range [0., 1.]
436
+ """
437
+ if control_mode in ['tile', 'lq']:
438
+ _, h, w = image.shape
439
+ down_scale = kwargs.get('down_scale', 4)
440
+ down_up = torchvision.transforms.Compose([
441
+ torchvision.transforms.Resize((h // down_scale, w // down_scale), interpolation=2), # 1 for lanczos and 2 for bilinear
442
+ torchvision.transforms.Resize((h, w), interpolation=2),
443
+ torchvision.transforms.ToPILImage()
444
+ ])
445
+ preprocessed = down_up(image)
446
+ elif control_mode == 'blur':
447
+ kernel_size = kwargs.get('kernel_size', 51)
448
+ sigma = kwargs.get('sigma', 2.0)
449
+ blur = torchvision.transforms.Compose([
450
+ torchvision.transforms.ToPILImage(),
451
+ torchvision.transforms.GaussianBlur(kernel_size, sigma),
452
+ ])
453
+ preprocessed = blur(image)
454
+ else:
455
+ raise NotImplementedError(f'Unexpected control mode {control_mode}')
456
+
457
+ if save_intermediate_results:
458
+ save_path = os.path.join(TMP_DIR, f'{self.uuid}_{control_mode}_controlnet_cond.png')
459
+ preprocessed.save(save_path)
460
+ logger.info(f'Save image to {save_path}')
461
+
462
+ return preprocessed
463
 
464
  def generate_3d_bundle_image_text(self,
465
  prompt,
466
  image=None,
467
  strength=1.0,
468
  lora_scale=1.0,
469
+ num_inference_steps=None,
470
+ seed=None,
471
+ redux_hparam=None,
472
  save_intermediate_results=True,
473
  **kwargs):
474
 
 
476
  return: gen_3d_bundle_image, torch.Tensor of shape (3, 1024, 2048), range [0., 1.]
477
  """
478
 
479
+ if isinstance(self.flux_pipeline, FluxImg2ImgPipeline):
480
  flux_pipeline = self.flux_pipeline
481
+ else:
482
+ flux_pipeline = convert_flux_pipeline(self.flux_pipeline, FluxImg2ImgPipeline)
483
 
484
  flux_device = self.config['flux'].get('device', 'cpu')
485
+ seed = seed or self.config['flux'].get('seed', 0)
486
+ num_inference_steps = num_inference_steps or self.config['flux'].get('num_inference_steps', 20)
487
+
488
+ if image is None:
489
+ image = torch.zeros((1, 3, 1024, 2048), dtype=torch.float32, device=flux_device)
490
 
491
  generator = torch.Generator(device=flux_device).manual_seed(seed)
492
 
493
+
494
  hparam_dict = {
495
+ 'prompt': 'A grid of 2x4 multi-view image, elevation 5. White background.',
496
+ 'prompt_2': ' '.join(['A grid of 2x4 multi-view image, elevation 5. White background.', prompt]),
497
+ 'image': image,
498
  'strength': strength,
499
  'num_inference_steps': num_inference_steps,
500
  'guidance_scale': 3.5,
 
507
  }
508
  hparam_dict.update(kwargs)
509
 
510
+ # do redux
511
+ if redux_hparam is not None:
512
+ assert self.flux_redux_pipeline is not None
513
+ assert 'image' in redux_hparam.keys()
514
+ redux_hparam_ = {
515
+ 'prompt': hparam_dict.pop('prompt'),
516
+ 'prompt_2': hparam_dict.pop('prompt_2'),
517
+ }
518
+ redux_hparam_.update(redux_hparam)
519
+
520
+ with self.context():
521
+ redux_output = self.flux_redux_pipeline(**redux_hparam_)
522
+
523
+ hparam_dict.update(redux_output)
524
+
525
+
526
  with self.context():
527
  gen_3d_bundle_image = flux_pipeline(**hparam_dict).images
528
 
 
536
 
537
  return gen_3d_bundle_image_
538
 
539
+ def reconstruct_3d_bundle_image(self,
540
+ image,
541
+ lrm_render_radius=4.15,
542
+ isomer_radius=4.5,
543
+ reconstruction_stage1_steps=0,
544
+ reconstruction_stage2_steps=20,
545
+ save_intermediate_results=True):
546
  """
547
  image: torch.Tensor, range [0., 1.], (3, 1024, 2048)
548
  """
 
552
  images = rearrange(image, 'c (n h) (m w) -> (n m) c h w', n=2, m=4) # (3, 1024, 2048) -> (8, 3, 512, 512)
553
  rgb_multi_view, normal_multi_view = images.chunk(2, dim=0)
554
  multi_view_mask = get_background(normal_multi_view).to(recon_device)
555
+ print(f'shape images: {images.shape}')
556
+ # breakpoint()
557
  rgb_multi_view = rgb_multi_view.to(recon_device) * multi_view_mask + (1 - multi_view_mask)
558
 
559
  with self.context():
 
561
  lrm_reconstruct(self.recon_model, self.recon_model_config.infer_config,
562
  rgb_multi_view.unsqueeze(0).to(recon_device), name=self.uuid,
563
  input_camera_type='kiss3d', render_3d_bundle_image=save_intermediate_results,
564
+ render_azimuths=[0, 90, 180, 270],
565
+ render_radius=lrm_render_radius)
566
 
567
  if save_intermediate_results:
568
  recon_3D_bundle_image = torchvision.utils.make_grid(torch.cat([lrm_multi_view_rgb.cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0) # range [0, 1]
569
+ torchvision.utils.save_image(recon_3D_bundle_image, os.path.join(TMP_DIR, f'{self.uuid}_lrm_recon_3d_bundle_image.png'))
570
 
571
  recon_mesh_path = os.path.join(TMP_DIR, f"{self.uuid}_isomer_recon_mesh.obj")
572
 
 
575
  multi_view_mask=multi_view_mask,
576
  vertices=vertices,
577
  faces=faces,
578
+ save_path=recon_mesh_path,
579
+ radius=isomer_radius,
580
+ reconstruction_stage1_steps=int(reconstruction_stage1_steps),
581
+ reconstruction_stage2_steps=int(reconstruction_stage2_steps)
582
+ )
583
 
584
 
585
  def run_text_to_3d(k3d_wrapper,
 
595
  if init_image_path is not None:
596
  init_image = Image.open(init_image_path)
597
 
598
+ # refine prompt
599
+ logger.info(f"Input prompt: \"{prompt}\"")
600
+
601
+ prompt = k3d_wrapper.get_detailed_prompt(prompt)
602
+
603
  gen_3d_bundle_image, gen_save_path = k3d_wrapper.generate_3d_bundle_image_text(prompt,
604
+ image=init_image,
605
+ strength=1.0,
606
+ save_intermediate_results=True)
607
+
608
+ # recon from 3D Bundle image
609
+ recon_mesh_path = k3d_wrapper.reconstruct_3d_bundle_image(gen_3d_bundle_image, save_intermediate_results=False)
610
+
611
+ return gen_save_path, recon_mesh_path
612
+
613
+ def image2mesh_preprocess(k3d_wrapper, input_image_, seed, use_mv_rgb=True):
614
+ seed_everything(seed)
615
+
616
+ # Renew The uuid
617
+ k3d_wrapper.renew_uuid()
618
+
619
+ # FOR IMAGE TO 3D: generate reference 3D bundle image from a single input image
620
+ input_image__ = Image.open(input_image_) if isinstance(input_image_, str) else input_image_
621
+
622
+ input_image = preprocess_input_image(input_image__)
623
+ input_image_save_path = os.path.join(TMP_DIR, f'{k3d_wrapper.uuid}_input_image.png')
624
+ input_image.save(input_image_save_path)
625
+
626
+ reference_3d_bundle_image, reference_save_path = k3d_wrapper.generate_reference_3D_bundle_image_zero123(input_image, use_mv_rgb=use_mv_rgb)
627
+ caption = k3d_wrapper.get_image_caption(input_image)
628
+
629
+ return input_image_save_path, reference_save_path, caption
630
+
631
+ def image2mesh_main(k3d_wrapper, input_image, reference_3d_bundle_image, caption, seed, strength1=0.5, strength2=0.95, enable_redux=True, use_controlnet=True):
632
+ seed_everything(seed)
633
+
634
+ if enable_redux:
635
+ redux_hparam = {
636
+ 'image': k3d_wrapper.to_512_tensor(input_image).unsqueeze(0).clip(0., 1.),
637
+ 'prompt_embeds_scale': 1.0,
638
+ 'pooled_prompt_embeds_scale': 1.0,
639
+ 'strength': strength1
640
+ }
641
+ else:
642
+ redux_hparam = None
643
+
644
+ if use_controlnet:
645
+ # prepare controlnet condition
646
+ control_mode = ['tile']
647
+ control_image = [k3d_wrapper.preprocess_controlnet_cond_image(reference_3d_bundle_image, mode_, down_scale=1, kernel_size=51, sigma=2.0) for mode_ in control_mode]
648
+ control_guidance_start = [0.0]
649
+ control_guidance_end = [0.3]
650
+ controlnet_conditioning_scale = [0.3]
651
+
652
+
653
+ gen_3d_bundle_image, gen_save_path = k3d_wrapper.generate_3d_bundle_image_controlnet(
654
+ prompt=caption,
655
+ image=reference_3d_bundle_image.unsqueeze(0),
656
+ strength=strength2,
657
+ control_image=control_image,
658
+ control_mode=control_mode,
659
+ control_guidance_start=control_guidance_start,
660
+ control_guidance_end=control_guidance_end,
661
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
662
+ lora_scale=1.0,
663
+ redux_hparam=redux_hparam
664
+ )
665
+ else:
666
+ gen_3d_bundle_image, gen_save_path = k3d_wrapper.generate_3d_bundle_image_text(
667
+ prompt=caption,
668
+ image=reference_3d_bundle_image.unsqueeze(0),
669
+ strength=strength2,
670
+ lora_scale=1.0,
671
+ redux_hparam=redux_hparam
672
+ )
673
 
674
  # recon from 3D Bundle image
675
  recon_mesh_path = k3d_wrapper.reconstruct_3d_bundle_image(gen_3d_bundle_image, save_intermediate_results=False)
676
 
677
  return gen_save_path, recon_mesh_path
678
 
679
+
680
+ def run_image_to_3d(k3d_wrapper, input_image_path, enable_redux=True, use_mv_rgb=True, use_controlnet=True):
681
  # ======================================= Example of image to 3D generation ======================================
682
 
683
  # Renew The uuid
684
  k3d_wrapper.renew_uuid()
685
 
686
  # FOR IMAGE TO 3D: generate reference 3D bundle image from a single input image
687
+ input_image = preprocess_input_image(Image.open(input_image_path))
688
+ input_image.save(os.path.join(TMP_DIR, f'{k3d_wrapper.uuid}_input_image.png'))
689
+
690
+ reference_3d_bundle_image, reference_save_path = k3d_wrapper.generate_reference_3D_bundle_image_zero123(input_image, use_mv_rgb=use_mv_rgb)
691
  caption = k3d_wrapper.get_image_caption(input_image)
692
 
693
+ if enable_redux:
694
+ redux_hparam = {
695
+ 'image': k3d_wrapper.to_512_tensor(input_image).unsqueeze(0).clip(0., 1.),
696
+ 'prompt_embeds_scale': 1.0,
697
+ 'pooled_prompt_embeds_scale': 1.0,
698
+ 'strength': 0.5
699
+ }
700
+ else:
701
+ redux_hparam = None
702
+
703
+ if use_controlnet:
704
+ # prepare controlnet condition
705
+ control_mode = ['tile']
706
+ control_image = [k3d_wrapper.preprocess_controlnet_cond_image(reference_3d_bundle_image, mode_, down_scale=1, kernel_size=51, sigma=2.0) for mode_ in control_mode]
707
+ control_guidance_start = [0.0]
708
+ control_guidance_end = [0.3]
709
+ controlnet_conditioning_scale = [0.3]
710
+
711
+
712
+ gen_3d_bundle_image, gen_save_path = k3d_wrapper.generate_3d_bundle_image_controlnet(
713
+ prompt=caption,
714
+ image=reference_3d_bundle_image.unsqueeze(0),
715
+ strength=.95,
716
+ control_image=control_image,
717
+ control_mode=control_mode,
718
+ control_guidance_start=control_guidance_start,
719
+ control_guidance_end=control_guidance_end,
720
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
721
+ lora_scale=1.0,
722
+ redux_hparam=redux_hparam
723
+ )
724
+ else:
725
+ gen_3d_bundle_image, gen_save_path = k3d_wrapper.generate_3d_bundle_image_text(
726
+ prompt=caption,
727
+ image=reference_3d_bundle_image.unsqueeze(0),
728
+ strength=.95,
729
+ lora_scale=1.0,
730
+ redux_hparam=redux_hparam
731
+ )
732
+
733
+ # recon from 3D Bundle image
734
+ recon_mesh_path = k3d_wrapper.reconstruct_3d_bundle_image(gen_3d_bundle_image, save_intermediate_results=False)
735
 
736
+ return gen_save_path, recon_mesh_path
 
737
 
738
 
739
  if __name__ == "__main__":
740
+ k3d_wrapper = init_wrapper_from_config('/hpc2hdd/home/jlin695/code/github/Kiss3DGen/pipeline/pipeline_config/default.yaml')
741
+
742
+ os.system(f'rm -rf {TMP_DIR}/*')
743
+ # os.system(f'rm -rf {OUT_DIR}/3d_bundle/*')
744
 
745
+ enable_redux = True
746
+ use_mv_rgb = True
747
+ use_controlnet = True
748
 
749
+ img_folder = '/hpc2hdd/home/jlin695/code/Kiss3DGen/examples'
750
+ for img_ in os.listdir(img_folder):
751
+ name, _ = os.path.splitext(img_)
752
+ print("Now processing:", name)
753
+
754
+ gen_save_path, recon_mesh_path = run_image_to_3d(k3d_wrapper, os.path.join(img_folder, img_), enable_redux, use_mv_rgb, use_controlnet)
755
+
756
+ os.system(f'cp -f {gen_save_path} {OUT_DIR}/3d_bundle/{name}_3d_bundle.png')
757
+ os.system(f'cp -f {recon_mesh_path} {OUT_DIR}/3d_bundle/{name}.obj')
758
+
759
+ # TODO exams:
760
+ # 1. redux True, mv_rgb False, Tile, down_scale = 1
761
+ # 2. redux False, mv_rgb True, Tile, down_scale = 8
762
+ # 3. redux False, mv_rgb False, Tile, blur = 10
763
+
764
+
765
  # run_text_to_3d(k3d_wrapper, prompt='A doll of a girl in Harry Potter')
766
 
767
+
768
+ # Example of loading existing 3D bundle Image as Tensor from path
769
+ # pseudo_image = Image.open('/hpc2hdd/home/jlin695/code/github/Kiss3DGen/outputs/tmp/fbf6edad-2d7f-49e5-8ac2-a05af5fe695b_ref_3d_bundle_image.png')
770
+ # gen_3d_bundle_image = torchvision.transforms.functional.to_tensor(pseudo_image)
pipeline/pipeline_config/default.yaml CHANGED
@@ -1,15 +1,19 @@
1
  flux:
2
- base_model: "/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/models--black-forest-labs--FLUX.1-dev"
3
- lora: "./checkpoint/flux_lora/rgb_normal_doll_object.safetensors"
4
- controlnet: "/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/flux_controlnets/FLUX.1-dev-ControlNet-Union-Pro"
5
- seed: 0
 
 
 
6
  device: 'cuda:0'
7
 
8
  multiview:
9
  base_model: "sudo-ai/zero123plus-v1.2"
10
  custom_pipeline: "./models/zero123plus"
11
  unet: "./checkpoint/zero123++/flexgen_19w.ckpt"
12
- num_inference_steps: 75
 
13
  device: 'cuda:0'
14
 
15
  reconstruction:
@@ -18,8 +22,12 @@ reconstruction:
18
  device: 'cuda:0'
19
 
20
  caption:
21
- base_model: "/hpc2hdd/home/jlin695/.cache/huggingface/hub/models--multimodalart--Florence-2-large-no-flash-attn/snapshots/8db3793cf5b453b2ccfb3a4f613b403b2e6b7ca2"
22
- device: 'cuda:0'
 
 
 
 
23
 
24
  use_zero_gpu: false # for huggingface demo only
25
- 3d_bundle_templates: '/hpc2hdd/home/jlin695/code/github/Kiss3DGen/init_3d_Bundle'
 
1
  flux:
2
+ base_model: "https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"
3
+ flux_dtype: 'fp8'
4
+ lora: "./checkpoint/flux_lora/rgb_normal_large.safetensors"
5
+ controlnet: "InstantX/FLUX.1-dev-Controlnet-Union"
6
+ redux: "black-forest-labs/FLUX.1-Redux-dev"
7
+ num_inference_steps: 20
8
+ seed: 42
9
  device: 'cuda:0'
10
 
11
  multiview:
12
  base_model: "sudo-ai/zero123plus-v1.2"
13
  custom_pipeline: "./models/zero123plus"
14
  unet: "./checkpoint/zero123++/flexgen_19w.ckpt"
15
+ num_inference_steps: 50
16
+ seed: 42
17
  device: 'cuda:0'
18
 
19
  reconstruction:
 
22
  device: 'cuda:0'
23
 
24
  caption:
25
+ base_model: "multimodalart/Florence-2-large-no-flash-attn"
26
+ device: 'cuda:1'
27
+
28
+ llm:
29
+ base_model: "Qwen/Qwen2-7B-Instruct"
30
+ device: 'cuda:1'
31
 
32
  use_zero_gpu: false # for huggingface demo only
33
+ 3d_bundle_templates: './init_3d_Bundle'
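
For reference, init_wrapper_from_config reads this file with plain PyYAML and falls back to defaults for missing keys; a minimal sketch of the consumption pattern (path illustrative):

    import yaml

    with open('pipeline/pipeline_config/default.yaml', 'r') as f:
        cfg = yaml.load(f, yaml.FullLoader)

    flux_device = cfg['flux'].get('device', 'cpu')  # 'cuda:0'
    flux_dtype = cfg['flux'].get('dtype', 'bf16')   # the file sets 'flux_dtype', so this lookup falls back to 'bf16'
    llm_cfg = cfg.get('llm', None)                  # optional block; None disables LLM prompt refinement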
pipeline/run_hpc.sh CHANGED
@@ -5,6 +5,6 @@ export CC=$(which gcc)
5
  export CPLUS_INCLUDE_PATH=/hpc2ssd/softwares/cuda/cuda-12.1/targets/x86_64-linux/include:$CPLUS_INCLUDE_PATH
6
  export CUDA_LAUNCH_BLOCKING=1
7
  export NCCL_TIMEOUT=3600
8
- export CUDA_VISIBLE_DEVICES="0"
9
 
10
  python ./pipeline/kiss3d_wrapper.py
 
5
  export CPLUS_INCLUDE_PATH=/hpc2ssd/softwares/cuda/cuda-12.1/targets/x86_64-linux/include:$CPLUS_INCLUDE_PATH
6
  export CUDA_LAUNCH_BLOCKING=1
7
  export NCCL_TIMEOUT=3600
8
+ export CUDA_VISIBLE_DEVICES="0,1"
9
 
10
  python ./pipeline/kiss3d_wrapper.py
run_hpc.sh → pipeline/run_hpc_text_to_3d.sh RENAMED
@@ -5,7 +5,6 @@ export CC=$(which gcc)
5
  export CPLUS_INCLUDE_PATH=/hpc2ssd/softwares/cuda/cuda-12.1/targets/x86_64-linux/include:$CPLUS_INCLUDE_PATH
6
  export CUDA_LAUNCH_BLOCKING=1
7
  export NCCL_TIMEOUT=3600
8
- export CUDA_VISIBLE_DEVICES="0"
9
- # python app.py
10
- python text_to_mesh.py
11
- # python image_to_mesh.py
 
5
  export CPLUS_INCLUDE_PATH=/hpc2ssd/softwares/cuda/cuda-12.1/targets/x86_64-linux/include:$CPLUS_INCLUDE_PATH
6
  export CUDA_LAUNCH_BLOCKING=1
7
  export NCCL_TIMEOUT=3600
8
+ export CUDA_VISIBLE_DEVICES="0,1"
9
+
10
+ python ./pipeline/example_text_to_3d.py
 
pipeline/utils.py CHANGED
@@ -10,18 +10,20 @@ print(__workdir__)
10
  import numpy as np
11
  import torch
12
  from torchvision.transforms import v2
 
 
13
 
14
  from models.lrm.online_render.render_single import load_mipmap
15
  from models.lrm.utils.camera_util import get_zero123plus_input_cameras, get_custom_zero123plus_input_cameras, get_flux_input_cameras
16
  from models.lrm.utils.render_utils import rotate_x, rotate_y
17
  from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
 
18
 
19
  from models.ISOMER.reconstruction_func import reconstruction
20
  from models.ISOMER.projection_func import projection
21
 
22
  from utils.tool import NormalTransfer, get_render_cameras_frames, get_background, get_render_cameras_video, render_frames, mask_fix
23
 
24
-
25
  logging.basicConfig(
26
  level = logging.INFO
27
  )
@@ -38,7 +40,7 @@ def lrm_reconstruct(model, infer_config, images,
38
  render_3d_bundle_image=True,
39
  render_azimuths=[270, 0, 90, 180],
40
  render_elevations=[5, 5, 5, 5],
41
- render_radius=4.5):
42
  """
43
  image: Tensor, shape (1, c, h, w)
44
  """
@@ -49,7 +51,7 @@ def lrm_reconstruct(model, infer_config, images,
49
  if input_camera_type == 'zero123':
50
  input_cameras = get_custom_zero123plus_input_cameras(batch_size=1, radius=3.5, fov=30).to(device)
51
  elif input_camera_type == 'kiss3d':
52
- input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device)
53
  else:
54
  raise NotImplementedError(f'Unexpected input camera type: {input_camera_type}')
55
 
@@ -142,9 +144,9 @@ def isomer_reconstruct(
142
  elevations=[5, 5, 5, 5],
143
  geo_weights=[1, 0.9, 1, 0.9],
144
  color_weights=[1, 0.5, 1, 0.5],
145
- reconstruction_stage1_steps=50,
146
  reconstruction_stage2_steps=50,
147
- radius=4.1):
148
 
149
  device = rgb_multi_view.device
150
  to_tensor_ = lambda x: torch.Tensor(x).float().to(device)
@@ -180,6 +182,7 @@ def isomer_reconstruct(
180
 
181
  multi_view_mask_proj = mask_fix(multi_view_mask, erode_dilate=-10, blur=5)
182
 
 
183
  logger.info(f"==> Running ISOMER projection ...")
184
  save_glb_addr = projection(
185
  meshes,
@@ -195,4 +198,29 @@ def isomer_reconstruct(
195
  )
196
 
197
  logger.info(f"==> Save mesh to {save_glb_addr} ...")
198
- return save_glb_addr
 
10
  import numpy as np
11
  import torch
12
  from torchvision.transforms import v2
13
+ from PIL import Image
14
+ import rembg
15
 
16
  from models.lrm.online_render.render_single import load_mipmap
17
  from models.lrm.utils.camera_util import get_zero123plus_input_cameras, get_custom_zero123plus_input_cameras, get_flux_input_cameras
18
  from models.lrm.utils.render_utils import rotate_x, rotate_y
19
  from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
20
+ from models.lrm.utils.infer_util import remove_background, resize_foreground
21
 
22
  from models.ISOMER.reconstruction_func import reconstruction
23
  from models.ISOMER.projection_func import projection
24
 
25
  from utils.tool import NormalTransfer, get_render_cameras_frames, get_background, get_render_cameras_video, render_frames, mask_fix
26
 
 
27
  logging.basicConfig(
28
  level = logging.INFO
29
  )
 
40
  render_3d_bundle_image=True,
41
  render_azimuths=[270, 0, 90, 180],
42
  render_elevations=[5, 5, 5, 5],
43
+ render_radius=4.15):
44
  """
45
  image: Tensor, shape (1, c, h, w)
46
  """
 
51
  if input_camera_type == 'zero123':
52
  input_cameras = get_custom_zero123plus_input_cameras(batch_size=1, radius=3.5, fov=30).to(device)
53
  elif input_camera_type == 'kiss3d':
54
+ input_cameras = get_flux_input_cameras(batch_size=1, radius=3.5, fov=30).to(device)
55
  else:
56
  raise NotImplementedError(f'Unexpected input camera type: {input_camera_type}')
57
 
 
144
  elevations=[5, 5, 5, 5],
145
  geo_weights=[1, 0.9, 1, 0.9],
146
  color_weights=[1, 0.5, 1, 0.5],
147
+ reconstruction_stage1_steps=10,
148
  reconstruction_stage2_steps=50,
149
+ radius=4.5):
150
 
151
  device = rgb_multi_view.device
152
  to_tensor_ = lambda x: torch.Tensor(x).float().to(device)
 
182
 
183
  multi_view_mask_proj = mask_fix(multi_view_mask, erode_dilate=-10, blur=5)
184
 
185
+
186
  logger.info(f"==> Running ISOMER projection ...")
187
  save_glb_addr = projection(
188
  meshes,
 
198
  )
199
 
200
  logger.info(f"==> Save mesh to {save_glb_addr} ...")
201
+ return save_glb_addr
202
+
203
+
204
+ def to_rgb_image(maybe_rgba):
205
+ assert isinstance(maybe_rgba, Image.Image)
206
+ if maybe_rgba.mode == 'RGB':
207
+ return maybe_rgba, None
208
+ elif maybe_rgba.mode == 'RGBA':
209
+ rgba = maybe_rgba
210
+ img = np.random.randint(127, 128, size=[rgba.size[1], rgba.size[0], 3], dtype=np.uint8)
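+ # randint(127, 128) always returns 127, i.e. a uniform mid-gray backdrop behind the RGBA foreground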
211
+ img = Image.fromarray(img, 'RGB')
212
+ img.paste(rgba, mask=rgba.getchannel('A'))
213
+ return img, rgba.getchannel('A')
214
+ else:
215
+ raise ValueError("Unsupported image type.", maybe_rgba.mode)
216
+
217
+ rembg_session = rembg.new_session("u2net")
218
+ def preprocess_input_image(input_image):
219
+ """
220
+ input_image: PIL.Image
221
+ output_image: PIL.Image, (3, 512, 512), mode = RGB, background = white
222
+ """
223
+ image = remove_background(to_rgb_image(input_image)[0], rembg_session, bgcolor=(255, 255, 255, 255))
224
+ image = resize_foreground(image, ratio=0.85, pad_value=255)
225
+ return to_rgb_image(image)[0]
226
+
run.sh DELETED
@@ -1,2 +0,0 @@
1
- export CUDA_VISIBLE_DEVICES="0"
2
- python text_to_mesh.py
 
 
 
text_to_mesh.py DELETED
@@ -1,232 +0,0 @@
1
- import os
2
- from einops import rearrange
3
- from omegaconf import OmegaConf
4
- import torch
5
- import numpy as np
6
- import trimesh
7
- import torchvision
8
- import torch.nn.functional as F
9
- from PIL import Image
10
- from torchvision import transforms
11
- from torchvision.transforms import v2
12
- from diffusers import HeunDiscreteScheduler
13
- from diffusers import FluxPipeline
14
- from pytorch_lightning import seed_everything
15
- import os
16
- import time
17
- from models.lrm.utils.infer_util import save_video
18
- from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
19
- from models.lrm.utils.render_utils import rotate_x, rotate_y
20
- from models.lrm.utils.train_util import instantiate_from_config
21
- from models.lrm.utils.camera_util import get_flux_input_cameras
22
- from models.ISOMER.reconstruction_func import reconstruction
23
- from models.ISOMER.projection_func import projection
24
- from utils.tool import NormalTransfer, load_mipmap
25
- from utils.tool import get_background, get_render_cameras_video, render_frames
26
-
27
- device = "cuda"
28
- resolution = 512
29
- save_dir = "./outputs"
30
- normal_transfer = NormalTransfer()
31
- isomer_azimuths = torch.from_numpy(np.array([0, 90, 180, 270])).float().to(device)
32
- isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).float().to(device)
33
- isomer_radius = 4.5
34
- isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device)
35
- isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device)
36
-
37
- # model initialization and loading
38
- # flux
39
- flux_pipe = FluxPipeline.from_pretrained("/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/models--black-forest-labs--FLUX.1-dev", torch_dtype=torch.bfloat16).to(device=device, dtype=torch.bfloat16)
40
- flux_pipe.load_lora_weights('./checkpoint/flux_lora/rgb_normal_large.safetensors')
41
-
42
- flux_pipe.to(device=device, dtype=torch.bfloat16)
43
- generator = torch.Generator(device=device).manual_seed(10)
44
-
45
- # lrm
46
- config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
47
- model_config = config.model_config
48
- infer_config = config.infer_config
49
- model = instantiate_from_config(model_config)
50
- model_ckpt_path = "./checkpoint/lrm/final_ckpt.ckpt"
51
- state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
52
- state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
53
- model.load_state_dict(state_dict, strict=True)
54
-
55
- model = model.to(device)
56
- model.init_flexicubes_geometry(device, fovy=50.0)
57
- model = model.eval()
58
-
59
- # Flux multi-view generation
60
- def multi_view_rgb_normal_generation(prompt, save_path=None):
61
- # generate multi-view images
62
- with torch.no_grad():
63
- image = flux_pipe(
64
- prompt=prompt,
65
- num_inference_steps=30,
66
- guidance_scale=3.5,
67
- num_images_per_prompt=1,
68
- width=resolution*4,
69
- height=resolution*2,
70
- output_type='np',
71
- generator=generator
72
- ).images
73
- return image
74
-
75
- # lrm reconstructions
76
- def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
77
- images = image.unsqueeze(0).to(device)
78
- images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
79
- # breakpoint()
80
- with torch.no_grad():
81
- # get triplane
82
- planes = model.forward_planes(images, input_cameras)
83
-
84
- mesh_path_idx = os.path.join(save_path, f'{name}.obj')
85
-
86
- mesh_out = model.extract_mesh(
87
- planes,
88
- use_texture_map=export_texmap,
89
- **infer_config,
90
- )
91
- if export_texmap:
92
- vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
93
- save_obj_with_mtl(
94
- vertices.data.cpu().numpy(),
95
- uvs.data.cpu().numpy(),
96
- faces.data.cpu().numpy(),
97
- mesh_tex_idx.data.cpu().numpy(),
98
- tex_map.permute(1, 2, 0).data.cpu().numpy(),
99
- mesh_path_idx,
100
- )
101
- else:
102
- vertices, faces, vertex_colors = mesh_out
103
- save_obj(vertices, faces, vertex_colors, mesh_path_idx)
104
- print(f"Mesh saved to {mesh_path_idx}")
105
-
106
- render_size = 512
107
- if if_save_video:
108
- video_path_idx = os.path.join(save_path, f'{name}.mp4')
109
- render_size = infer_config.render_resolution
110
- ENV = load_mipmap("models/lrm/env_mipmap/6")
111
- materials = (0.0,0.9)
112
-
113
- all_mv, all_mvp, all_campos = get_render_cameras_video(
114
- batch_size=1,
115
- M=240,
116
- radius=4.5,
117
- elevation=(90, 60.0),
118
- is_flexicubes=True,
119
- fov=30
120
- )
121
-
122
- frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
123
- model,
124
- planes,
125
- render_cameras=all_mvp,
126
- camera_pos=all_campos,
127
- env=ENV,
128
- materials=materials,
129
- render_size=render_size,
130
- chunk_size=20,
131
- is_flexicubes=True,
132
- )
133
- normals = (torch.nn.functional.normalize(normals) + 1) / 2
134
- normals = normals * alphas + (1-alphas)
135
- all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
136
-
137
- save_video(
138
- all_frames,
139
- video_path_idx,
140
- fps=30,
141
- )
142
- print(f"Video saved to {video_path_idx}")
143
-
144
- return vertices, faces
145
-
146
-
147
- def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg):
148
- if local_normal_images.min() >= 0:
149
- local_normal = local_normal_images.float() * 2 - 1
150
- else:
151
- local_normal = local_normal_images.float()
152
- global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
153
- global_normal[...,0] *= -1
154
- global_normal = (global_normal + 1) / 2
155
- global_normal = global_normal.permute(0, 3, 1, 2)
156
- return global_normal
157
-
158
- def main():
159
- end = time.time()
160
- fix_prompt = 'a grid of 2x4 multi-view image. elevation 5. white background.'
161
- # user prompt
162
- prompt = "a owl wearing a hat."
163
- save_dir_path = os.path.join(save_dir, prompt.split(".")[0].replace(" ", "_"))
164
- os.makedirs(save_dir_path, exist_ok=True)
165
- prompt = fix_prompt+" "+prompt
166
- # generate multi-view images
167
- rgb_normal_grid = multi_view_rgb_normal_generation(prompt)
168
- # lrm reconstructions
169
- images = torch.from_numpy(rgb_normal_grid).squeeze(0).permute(2, 0, 1).contiguous().float() # (3, 1024, 2048)
170
- images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=2, m=4) # (8, 3, 512, 512)
171
- rgb_multi_view = images[:4, :3, :, :]
172
- normal_multi_view = images[4:, :3, :, :]
173
- multi_view_mask = get_background(normal_multi_view)
174
- rgb_multi_view = rgb_multi_view * rgb_multi_view + (1-multi_view_mask)
175
- input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device)
176
- vertices, faces = lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', export_texmap=False, if_save_video=False)
177
- # local normal to global normal
178
-
179
- global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1), isomer_azimuths, isomer_elevations)
180
- global_normal = global_normal * multi_view_mask + (1-multi_view_mask)
181
-
182
- global_normal = global_normal.permute(0,2,3,1)
183
- rgb_multi_view = rgb_multi_view.permute(0,2,3,1)
184
- multi_view_mask = multi_view_mask.permute(0,2,3,1).squeeze(-1)
185
- vertices = torch.from_numpy(vertices).to(device)
186
- faces = torch.from_numpy(faces).to(device)
187
- vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
188
- vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]
189
-
190
- # global_normal: B,H,W,3
191
- # multi_view_mask: B,H,W
192
- # rgb_multi_view: B,H,W,3
193
-
194
- meshes = reconstruction(
195
- normal_pils=global_normal,
196
- masks=multi_view_mask,
197
- weights=isomer_geo_weights,
198
- fov=30,
199
- radius=isomer_radius,
200
- camera_angles_azi=isomer_azimuths,
201
- camera_angles_ele=isomer_elevations,
202
- expansion_weight_stage1=0.1,
203
- init_type="file",
204
- init_verts=vertices,
205
- init_faces=faces,
206
- stage1_steps=0,
207
- stage2_steps=50,
208
- start_edge_len_stage1=0.1,
209
- end_edge_len_stage1=0.02,
210
- start_edge_len_stage2=0.02,
211
- end_edge_len_stage2=0.005,
212
- )
213
-
214
-
215
- save_glb_addr = projection(
216
- meshes,
217
- masks=multi_view_mask,
218
- images=rgb_multi_view,
219
- azimuths=isomer_azimuths,
220
- elevations=isomer_elevations,
221
- weights=isomer_color_weights,
222
- fov=30,
223
- radius=isomer_radius,
224
- save_dir=f"{save_dir_path}/ISOMER/",
225
- )
226
- print(f'saved to {save_glb_addr}')
227
- print(f"Time elapsed: {time.time() - end:.2f}s")
228
-
229
-
230
-
231
- if __name__ == '__main__':
232
- main()
 
text_to_mesh_new.py DELETED
@@ -1,244 +0,0 @@
1
- import os
2
- from einops import rearrange
3
- from omegaconf import OmegaConf
4
- import torch
5
- import numpy as np
6
- import trimesh
7
- import torchvision
8
- import torch.nn.functional as F
9
- from PIL import Image
10
- from torchvision import transforms
11
- from torchvision.transforms import v2
12
- from diffusers import HeunDiscreteScheduler
13
- from diffusers import FluxPipeline
14
- from pytorch_lightning import seed_everything
15
- import os
16
-
17
- import time
18
-
19
- from models.lrm.utils.infer_util import save_video
20
- from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
21
- from models.lrm.utils.render_utils import rotate_x, rotate_y
22
- from models.lrm.utils.train_util import instantiate_from_config
23
- from models.lrm.utils.camera_util import get_flux_input_cameras
24
- from models.ISOMER.reconstruction_func import reconstruction
25
- from models.ISOMER.projection_func import projection
26
- from utils.tool import NormalTransfer, load_mipmap
27
- from utils.tool import get_background, get_render_cameras_video, render_frames, mask_fix
28
-
29
- device = "cuda"
30
- resolution = 512
31
- save_dir = "./outputs/text2"
32
- normal_transfer = NormalTransfer()
33
- isomer_azimuths = torch.from_numpy(np.array([0, 90, 180, 270])).float().to(device)
34
- isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).float().to(device)
35
- isomer_radius = 4.5
36
- isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device)
37
- isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device)
38
-
39
- # model initialization and loading
40
- # flux
41
- flux_pipe = FluxPipeline.from_pretrained("/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/models--black-forest-labs--FLUX.1-dev", torch_dtype=torch.bfloat16).to(device=device, dtype=torch.bfloat16)
42
- flux_pipe.load_lora_weights('./checkpoint/flux_lora/rgb_normal_large.safetensors')
43
-
44
- flux_pipe.to(device=device, dtype=torch.bfloat16)
45
- generator = torch.Generator(device=device).manual_seed(10)
46
-
47
- # lrm
48
- config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
49
- model_config = config.model_config
50
- infer_config = config.infer_config
51
- model = instantiate_from_config(model_config)
52
- model_ckpt_path = "./checkpoint/lrm/final_ckpt.ckpt"
53
- state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
54
- state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
55
- model.load_state_dict(state_dict, strict=True)
56
-
57
- model = model.to(device)
58
- model.init_flexicubes_geometry(device, fovy=50.0)
59
- model = model.eval()
60
-
61
- # Flux multi-view generation
62
- def multi_view_rgb_normal_generation(prompt, save_path=None):
63
- # generate multi-view images
64
- with torch.no_grad():
65
- image = flux_pipe(
66
- prompt=prompt,
67
- num_inference_steps=30,
68
- guidance_scale=3.5,
69
- num_images_per_prompt=1,
70
- width=resolution*4,
71
- height=resolution*2,
72
- output_type='np',
73
- generator=generator
74
- ).images
75
- return image
76
-
77
- # lrm reconstructions
78
- def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
79
- images = image.unsqueeze(0).to(device)
80
- images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
81
- # breakpoint()
82
- with torch.no_grad():
83
- # get triplane
84
- planes = model.forward_planes(images, input_cameras)
85
-
86
- mesh_path_idx = os.path.join(save_path, f'{name}.obj')
87
-
88
- mesh_out = model.extract_mesh(
89
- planes,
90
- use_texture_map=export_texmap,
91
- **infer_config,
92
- )
93
- if export_texmap:
94
- vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
95
- save_obj_with_mtl(
96
- vertices.data.cpu().numpy(),
97
- uvs.data.cpu().numpy(),
98
- faces.data.cpu().numpy(),
99
- mesh_tex_idx.data.cpu().numpy(),
100
- tex_map.permute(1, 2, 0).data.cpu().numpy(),
101
- mesh_path_idx,
102
- )
103
- else:
104
- vertices, faces, vertex_colors = mesh_out
105
- save_obj(vertices, faces, vertex_colors, mesh_path_idx)
106
- print(f"Mesh saved to {mesh_path_idx}")
107
-
108
- render_size = 512
109
- if if_save_video:
110
- video_path_idx = os.path.join(save_path, f'{name}.mp4')
111
- render_size = infer_config.render_resolution
112
- ENV = load_mipmap("models/lrm/env_mipmap/6")
113
- materials = (0.0,0.9)
114
-
115
- all_mv, all_mvp, all_campos = get_render_cameras_video(
116
- batch_size=1,
117
- M=240,
118
- radius=4.5,
119
- elevation=(90, 60.0),
120
- is_flexicubes=True,
121
- fov=30
122
- )
123
-
124
- frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
125
- model,
126
- planes,
127
- render_cameras=all_mvp,
128
- camera_pos=all_campos,
129
- env=ENV,
130
- materials=materials,
131
- render_size=render_size,
132
- chunk_size=20,
133
- is_flexicubes=True,
134
- )
135
- normals = (torch.nn.functional.normalize(normals) + 1) / 2
136
- normals = normals * alphas + (1-alphas)
137
- all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
138
-
139
- save_video(
140
- all_frames,
141
- video_path_idx,
142
- fps=30,
143
- )
144
- print(f"Video saved to {video_path_idx}")
145
-
146
- return vertices, faces
147
-
148
-
149
- def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg):
150
- if local_normal_images.min() >= 0:
151
- local_normal = local_normal_images.float() * 2 - 1
152
- else:
153
- local_normal = local_normal_images.float()
154
- global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
155
- global_normal[...,0] *= -1
156
- global_normal = (global_normal + 1) / 2
157
- global_normal = global_normal.permute(0, 3, 1, 2)
158
- return global_normal
159
-
160
- def main(prompt = "a owl wearing a hat."):
161
- fix_prompt = 'a grid of 2x4 multi-view image. elevation 5. white background.'
162
- # user prompt
163
-
164
- save_dir_path = os.path.join(save_dir, prompt.split(".")[0].replace(" ", "_"))
165
- os.makedirs(save_dir_path, exist_ok=True)
166
- prompt = fix_prompt+" "+prompt
167
- # generate multi-view images
168
- rgb_normal_grid = multi_view_rgb_normal_generation(prompt)
169
- # lrm reconstructions
170
- images = torch.from_numpy(rgb_normal_grid).squeeze(0).permute(2, 0, 1).contiguous().float() # (3, 1024, 2048)
171
- images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=2, m=4) # (8, 3, 512, 512)
172
- rgb_multi_view = images[:4, :3, :, :]
173
- normal_multi_view = images[4:, :3, :, :]
174
- multi_view_mask = get_background(normal_multi_view)
175
- rgb_multi_view = rgb_multi_view * rgb_multi_view + (1-multi_view_mask)
176
- input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device)
177
- vertices, faces = lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', export_texmap=False, if_save_video=False)
178
- # local normal to global normal
179
-
180
- global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1), isomer_azimuths, isomer_elevations)
181
- global_normal = global_normal * multi_view_mask + (1-multi_view_mask)
182
-
183
- global_normal = global_normal.permute(0,2,3,1)
184
- rgb_multi_view = rgb_multi_view.permute(0,2,3,1)
185
- multi_view_mask = multi_view_mask.permute(0,2,3,1).squeeze(-1)
186
- vertices = torch.from_numpy(vertices).to(device)
187
- faces = torch.from_numpy(faces).to(device)
188
- vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
189
- vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]
190
-
191
- # global_normal: B,H,W,3
192
- # multi_view_mask: B,H,W
193
- # rgb_multi_view: B,H,W,3
194
-
195
- multi_view_mask_proj = mask_fix(multi_view_mask, erode_dilate=-6, blur=5)
196
-
197
- meshes = reconstruction(
198
- normal_pils=global_normal,
199
- masks=multi_view_mask,
200
- weights=isomer_geo_weights,
201
- fov=30,
202
- radius=isomer_radius,
203
- camera_angles_azi=isomer_azimuths,
204
- camera_angles_ele=isomer_elevations,
205
- expansion_weight_stage1=0.1,
206
- init_type="file",
207
- init_verts=vertices,
208
- init_faces=faces,
209
- stage1_steps=0,
210
- stage2_steps=50,
211
- start_edge_len_stage1=0.1,
212
- end_edge_len_stage1=0.02,
213
- start_edge_len_stage2=0.02,
214
- end_edge_len_stage2=0.005,
215
- )
216
-
217
-
218
- multi_view_mask_proj = mask_fix(multi_view_mask, erode_dilate=-10, blur=5)
219
-
220
- save_glb_addr = projection(
221
- meshes,
222
- masks=multi_view_mask_proj,
223
- images=rgb_multi_view,
224
- azimuths=isomer_azimuths,
225
- elevations=isomer_elevations,
226
- weights=isomer_color_weights,
227
- fov=30,
228
- radius=isomer_radius,
229
- save_dir=f"{save_dir_path}/ISOMER/",
230
- )
231
- print(f'saved to {save_glb_addr}')
232
-
233
-
234
-
235
- if __name__ == '__main__':
236
- import time
237
- start_time = time.time()
238
- prompts = ["A red dragon soaring", "A running Chihuahua", "A dancing rabbit", "A girl with blue hair and white dress", "A teacher", "A tiger playing guitar", "A red rose", "A red peony", "A rose in a vase", "A golden retriever sitting", "A golden retriever running"]
239
- for prompt in prompts:
240
- main(prompt)
241
- end_time = time.time()
242
- print(f"Time taken: {end_time - start_time:.2f} seconds for {len(prompts)} prompts")
243
-
244
- breakpoint()
 
upload_huggingface.py DELETED
@@ -1,57 +0,0 @@
1
- from huggingface_hub import HfApi, HfFolder, Repository, create_repo, upload_file
2
- import os
3
-
4
- # Log in to Hugging Face
5
- from huggingface_hub import login
6
- login()
7
-
8
- # Create or point to an existing repository
9
- repo_name = "xxx-ckpt"
10
- username = "LTT"
11
- repo_id = f"{username}/{repo_name}"
12
-
13
- # Create the repo if it does not already exist
14
- create_repo(repo_id, exist_ok=True)
15
-
16
- # Folders
17
- # Upload an entire folder
18
- def upload_folder(folder_path, repo_id):
19
- """
20
- Recursively upload a folder and its contents to the Hugging Face repo.
21
- """
22
- for root, _, files in os.walk(folder_path):
23
- for file in files:
24
- # full path of the file
25
- full_file_path = os.path.join(root, file)
26
- # path relative to the folder (preserves the folder structure)
27
- relative_path = os.path.relpath(full_file_path, folder_path)
28
-
29
- # upload the file to the repo
30
- print(f"Uploading {relative_path}...")
31
- upload_file(
32
- path_or_fileobj=full_file_path,
33
- path_in_repo=relative_path,
34
- repo_id=repo_id
35
- )
36
- print(f"Uploaded {relative_path} successfully.")
37
-
38
-
39
- # Upload the model checkpoint
40
- model_path = "checkpoint/zero123++/flexgen_19w.ckpt"
41
- upload_file(path_or_fileobj=model_path, path_in_repo="flexgen_19w.ckpt", repo_id=repo_id)
42
-
43
- # # Upload data files
44
- # data_path = "/hpc2hdd/home/jlin695/data/env_map/data/env_mipmap_large.tar.gz"
45
- # upload_file(path_or_fileobj=data_path, path_in_repo="env_mipmap_large.tar.gz", repo_id=repo_id)
46
-
47
- # # Upload data files
48
- # data_path = "/hpc2hdd/home/jlin695/data/env_map/data/env_map_light_large.tar.gz"
49
- # upload_file(path_or_fileobj=data_path, path_in_repo="env_map_light_large.tar.gz", repo_id=repo_id)
50
-
51
- # # Folder path to upload
52
- # folder_path = "checkpoint/flux_lora"
53
-
54
- # # Call the folder-upload helper
55
- # upload_folder(folder_path, repo_id)
56
-
57
- # print("Model and data files have been uploaded to Hugging Face.")
 
video_render.py ADDED
@@ -0,0 +1,123 @@
1
+ import os
2
+ import math
3
+ import numpy as np
4
+ import imageio
5
+ import trimesh
6
+ import pyrender
7
+ from tqdm import tqdm
8
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "7"
9
+ os.environ['PYOPENGL_PLATFORM'] = 'egl' # use EGL for headless (offscreen) rendering
10
+
11
+ def render_video_from_obj(input_obj_path, output_video_path, fps=15, frame_count=60, resolution=(512, 512)):
12
+ """
13
+ Render a rotating 3D model (OBJ file) to a video with RGB and normal map side-by-side.
14
+
15
+ Args:
16
+ input_obj_path (str): Path to the input OBJ file.
17
+ output_video_path (str): Path to save the output video.
18
+ fps (int): Frames per second for the video.
19
+ frame_count (int): Number of frames in the video.
20
+ resolution (tuple): Resolution of the rendered video (width, height).
21
+
22
+ Returns:
23
+ str: Path to the output video.
24
+ """
25
+ # check that the input file exists
26
+ if not os.path.exists(input_obj_path):
27
+ raise FileNotFoundError(f"Input OBJ file not found: {input_obj_path}")
28
+
29
+ # load the 3D model
30
+ scene_data = trimesh.load(input_obj_path)
31
+
32
+ # extract or merge the meshes in the scene
33
+ if isinstance(scene_data, trimesh.Scene):
34
+ mesh_data = trimesh.util.concatenate([geom for geom in scene_data.geometry.values()])
35
+ else:
36
+ mesh_data = scene_data
37
+
38
+ # make sure vertex normals are available
39
+ if not hasattr(mesh_data, 'vertex_normals') or mesh_data.vertex_normals is None:
40
+ mesh_data.compute_vertex_normals()
41
+
42
+ # create the Pyrender scene and set a white background
43
+ render_scene = pyrender.Scene(bg_color=[1.0, 1.0, 1.0])
44
+ mesh = pyrender.Mesh.from_trimesh(mesh_data, smooth=True)
45
+ mesh_node = render_scene.add(mesh)
46
+
47
+ # set up the camera
48
+ camera = pyrender.PerspectiveCamera(yfov=np.deg2rad(30), znear=0.0001, zfar=100000.0)
49
+ camera_pose = np.eye(4)
50
+ camera_pose[2, 3] = 4.0 # place the camera 4 units away from the model
51
+ render_scene.add(camera, pose=camera_pose)
52
+
53
+ # add global ambient light
54
+ ambient_light = np.array([1.0, 1.0, 1.0]) * 2.0
55
+ render_scene.ambient_light = ambient_light
56
+
57
+ # prepare the data for normal-map rendering
58
+ normals = mesh_data.vertex_normals.copy()
59
+
60
+ # map normals into the [0, 255] color range
61
+ normal_colors = ((normals + 1) / 2 * 255)
62
+
63
+ # create a separate mesh for normal rendering
64
+ normal_mesh_data = mesh_data.copy()
65
+ normal_mesh_data.visual.vertex_colors = np.hstack(
66
+ [normal_colors, np.full((normals.shape[0], 1), 255, dtype=np.uint8)] # append an alpha channel
67
+ )
68
+
69
+ # build the normal rendering scene
70
+ normal_scene = pyrender.Scene(bg_color=[1.0, 1.0, 1.0, 1.0])
71
+ normal_mesh = pyrender.Mesh.from_trimesh(normal_mesh_data, smooth=True)
72
+ normal_mesh_node = normal_scene.add(normal_mesh)
73
+ normal_scene.add(camera, pose=camera_pose)
74
+ normal_scene.ambient_light = ambient_light
75
+
76
+ # initialize the offscreen renderer
77
+ r = pyrender.OffscreenRenderer(*resolution)
78
+
79
+ # create the video writer
80
+ writer = imageio.get_writer(output_video_path, fps=fps)
81
+
82
+ # render each frame
83
+ try:
84
+ for frame_idx in tqdm(range(frame_count)):
85
+ # compute the rotation angle for this frame
86
+ angle = 2 * np.pi * frame_idx / frame_count
87
+ rotation_matrix = np.array([
88
+ [math.cos(angle), 0, math.sin(angle), 0],
89
+ [0, 1, 0, 0],
90
+ [-math.sin(angle), 0, math.cos(angle), 0],
91
+ [0, 0, 0, 1]
92
+ ])
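+ # a pure rotation about the world Y axis: the mesh spins turntable-style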
93
+
94
+ # update the model pose
95
+ render_scene.set_pose(mesh_node, rotation_matrix)
96
+
97
+ # render the RGB image
98
+ color, _ = r.render(render_scene)
99
+
100
+ # update the pose in the normal scene
101
+ normal_scene.set_pose(normal_mesh_node, rotation_matrix)
102
+
103
+ # render the normal-map image
104
+ normal, _ = r.render(normal_scene, flags=pyrender.RenderFlags.FLAT)
105
+
106
+ # concatenate the RGB and normal images side by side
107
+ combined_frame = np.concatenate((color, normal), axis=1)
108
+
109
+ # write the video frame
110
+ writer.append_data(combined_frame)
111
+ finally:
112
+ # release resources
113
+ writer.close()
114
+ r.delete()
115
+
116
+ print(f"Rendered video saved to {output_video_path}")
117
+ return output_video_path
118
+
119
+ if __name__ == '__main__':
120
+ # example invocation
121
+ input_obj_path = "output/gradio_cache/text_3D/_超级赛亚人_10/rgb_projected.obj"
122
+ output_video_path = "output.mp4"
123
+ render_video_from_obj(input_obj_path, output_video_path)