Update gradio_app.py
gradio_app.py  CHANGED  (+79 -24)
@@ -14,6 +14,7 @@ from collections import OrderedDict
 import trimesh
 import gradio as gr
 from typing import Any
+from einops import rearrange
 
 proj_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 sys.path.append(os.path.join(proj_dir))
@@ -58,6 +59,8 @@ If you have any questions, feel free to open a discussion or contact us at <b>we
 """
 from apps.third_party.CRM.pipelines import TwoStagePipeline
 from apps.third_party.LGM.pipeline_mvdream import MVDreamPipeline
+from apps.third_party.Era3D.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline
+from apps.third_party.Era3D.data.single_image_dataset import SingleImageDataset
 
 import re
 import os
@@ -88,22 +91,25 @@ chmod(f"{parent_dir}/apps/third_party/InstantMeshes", "777")
 
 model = None
 cached_dir = None
-
-
-
-stage1_model_config.resume = hf_hub_download(repo_id="Zhengyi/CRM", filename="pixel-diffusion.pth", repo_type="model")
-stage1_model_config.config = f"{parent_dir}/apps/third_party/CRM/" + stage1_model_config.config
+generator = None
+
+sys.path.append(f"apps/third_party/CRM")
 crm_pipeline = None
 
 sys.path.append(f"apps/third_party/LGM")
 imgaedream_pipeline = None
 
+sys.path.append(f"apps/third_party/Era3D")
+era3d_pipeline = None
+
 @spaces.GPU
 def gen_mvimg(
     mvimg_model, image, seed, guidance_scale, step, text, neg_text, elevation, backgroud_color
 ):
     if seed == 0:
         seed = np.random.randint(1, 65535)
+    global generator
+    generator.manual_seed(seed)
 
     if mvimg_model == "CRM":
         global crm_pipeline
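Note on the new module-level generator (a minimal sketch, not part of the commit): re-seeding a single torch.Generator per request is what makes a given seed reproduce the same multi-view set, since diffusers-style pipelines draw their initial latents from the generator they are handed.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
generator = torch.Generator(device)

generator.manual_seed(42)
a = torch.randn(4, device=device, generator=generator)
generator.manual_seed(42)
b = torch.randn(4, device=device, generator=generator)
assert torch.equal(a, b)  # same seed -> same noise -> same views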
@@ -118,7 +124,7 @@ def gen_mvimg(
         return mv_imgs[5], mv_imgs[3], mv_imgs[2], mv_imgs[0]
 
     elif mvimg_model == "ImageDream":
-        global imagedream_pipeline
+        global imagedream_pipeline
         background = Image.new("RGBA", image.size, backgroud_color)
         image = Image.alpha_composite(background, image)
         image = np.array(image).astype(np.float32) / 255.0
@@ -130,9 +136,36 @@ def gen_mvimg(
             guidance_scale=guidance_scale,
             num_inference_steps=step,
             elevation=elevation,
+            generator=generator,
         )
         return mv_imgs[1], mv_imgs[2], mv_imgs[3], mv_imgs[0]
-
+
+    elif mvimg_model == "Era3D":
+        global era3d_pipeline
+        crop_size = 420
+        batch = SingleImageDataset(root_dir='', num_views=6, img_wh=[512, 512], bg_color='white',
+                                   crop_size=crop_size, single_image=image, prompt_embeds_path='apps/third_party/Era3D/data/fixed_prompt_embeds_6view')[0]
+        imgs_in = torch.cat([batch['imgs_in']]*2, dim=0)
+        imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")  # (B*Nv, 3, H, W)
+
+        normal_prompt_embeddings, clr_prompt_embeddings = batch['normal_prompt_embeddings'], batch['color_prompt_embeddings']
+        prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0)
+        prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C")
+
+        imgs_in = imgs_in.to(device=device, dtype=torch.float16)
+        prompt_embeddings = prompt_embeddings.to(device=device, dtype=torch.float16)
+
+        mv_imgs = era3d_pipeline(
+            imgs_in,
+            None,
+            prompt_embeds=prompt_embeddings,
+            generator=generator,
+            guidance_scale=guidance_scale,
+            num_inference_steps=step,
+            num_images_per_prompt=1,
+            **{'eta': 1.0}
+        ).images
+        return mv_imgs[6], mv_imgs[8], mv_imgs[9], mv_imgs[10]
 
 @spaces.GPU
 def image2mesh(view_front: np.ndarray,
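The Era3D branch's tensor bookkeeping, restated as a standalone sketch (shapes are assumptions read off the code above: six 512x512 views, duplicated once because the batch is paired with normal-domain and color-domain prompt embeddings):

import torch
from einops import rearrange

B, Nv, C, H, W = 1, 6, 3, 512, 512      # one input, six target views
views = torch.rand(B, Nv, C, H, W)      # stands in for batch['imgs_in']
paired = torch.cat([views] * 2, dim=0)  # one copy per domain: (2, 6, 3, 512, 512)
flat = rearrange(paired, "B Nv C H W -> (B Nv) C H W")
print(flat.shape)                       # torch.Size([12, 3, 512, 512])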
@@ -209,24 +242,46 @@ if __name__=="__main__":
         "Auto Remove Background": "Auto Remove Background",
         "Original Image": "Original Image",
     })
-    mvimg_model_config_list = [
-
-
-
-
-
-
-
-
-
-
+    mvimg_model_config_list = [
+        "Era3D",
+        # "CRM",
+        # "ImageDream"
+    ]
+    if "Era3D" in mvimg_model_config_list:
+        # cfg = load_config("apps/third_party/Era3D/configs/test_unclip-512-6view.yaml")
+        # schema = OmegaConf.structured(TestConfig)
+        # cfg = OmegaConf.merge(schema, cfg)
+        era3d_pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
+            'pengHTYX/MacLab-Era3D-512-6view',
+            torch_dtype=torch.float16
+        )
+        # enable xformers
+        era3d_pipeline.unet.enable_xformers_memory_efficient_attention()
+        era3d_pipeline.to(device)
+    elif "CRM" in mvimg_model_config_list:
+        stage1_config = OmegaConf.load(f"apps/third_party/CRM/configs/nf7_v3_SNR_rd_size_stroke.yaml").config
+        stage1_sampler_config = stage1_config.sampler
+        stage1_model_config = stage1_config.models
+        stage1_model_config.resume = hf_hub_download(repo_id="Zhengyi/CRM", filename="pixel-diffusion.pth", repo_type="model")
+        stage1_model_config.config = f"apps/third_party/CRM/" + stage1_model_config.config
+        crm_pipeline = TwoStagePipeline(
+            stage1_model_config,
+            stage1_sampler_config,
+            device=device,
+            dtype=torch.float16
+        )
+    elif "ImageDream" in mvimg_model_config_list:
+        imagedream_pipeline = MVDreamPipeline.from_pretrained(
+            "ashawkey/imagedream-ipmv-diffusers",  # remote weights
+            torch_dtype=torch.float16,
+            trust_remote_code=True,
+        )
+
+    generator = torch.Generator(device)
 
     # for 3D latent set diffusion
-    ckpt_path = "
-    config_path = "
-    # ckpt_path = hf_hub_download(repo_id="wyysf/CraftsMan", filename="image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6/model.ckpt", repo_type="model")
-    # config_path = hf_hub_download(repo_id="wyysf/CraftsMan", filename="image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6/config.yaml", repo_type="model")
+    ckpt_path = hf_hub_download(repo_id="wyysf/CraftsMan", filename="image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6-aligned-vae/model.ckpt", repo_type="model")
+    config_path = hf_hub_download(repo_id="wyysf/CraftsMan", filename="image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6-aligned-vae/config.yaml", repo_type="model")
     # ckpt_path = hf_hub_download(repo_id="wyysf/CraftsMan", filename="image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6/model-300k.ckpt", repo_type="model")
     # config_path = hf_hub_download(repo_id="wyysf/CraftsMan", filename="image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6/config.yaml", repo_type="model")
     scheluder_dict = OrderedDict({
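Two asides on this hunk. Because of the if/elif chain, only the first model named in mvimg_model_config_list is initialized, so a single multi-view pipeline is resident at a time. And the new ckpt_path/config_path lines resolve to ordinary local files: hf_hub_download returns a path into the Hugging Face cache. A sketch restating one of the calls added above:

from huggingface_hub import hf_hub_download

config_path = hf_hub_download(
    repo_id="wyysf/CraftsMan",
    filename="image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6-aligned-vae/config.yaml",
    repo_type="model",
)
print(config_path)  # a cached local path, by default under ~/.cache/huggingface/hub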
@@ -266,7 +321,7 @@ if __name__=="__main__":
         gr.Markdown('''Try a different <b>seed and MV Model</b> for better results. Good Luck :)''')
         with gr.Row():
             seed = gr.Number(0, label='Seed', show_label=True)
-            mvimg_model = gr.Dropdown(value="
+            mvimg_model = gr.Dropdown(value="Era3D", label="MV Image Model", choices=list(mvimg_model_config_list))
             more = gr.CheckboxGroup(["Remesh", "Symmetry(TBD)"], label="More", show_label=False)
         with gr.Row():
             # input prompt
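For reference, a hypothetical smoke test of the new path (argument values are assumptions, not part of the commit; gen_mvimg is the function patched above, called with its positional signature):

from PIL import Image

img = Image.open("input.png").convert("RGBA")  # hypothetical input file
v0, v1, v2, v3 = gen_mvimg(
    "Era3D",                # mvimg_model
    img,                    # image
    42,                     # seed (0 would draw a random one)
    3.0,                    # guidance_scale (assumed value)
    40,                     # step (assumed value)
    "", "",                 # text, neg_text (unused by the Era3D branch)
    0,                      # elevation (unused by the Era3D branch)
    (255, 255, 255, 255),   # backgroud_color (unused by the Era3D branch)
)
# v0..v3 are four of the six generated views, per the indices the branch returns.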