import os
import tempfile
import time
from functools import lru_cache
from typing import Any

import gradio as gr
import numpy as np
import rembg
import torch
from gradio_litmodel3d import LitModel3D
from PIL import Image

import sf3d.utils as sf3d_utils
from sf3d.system import SF3D

# One background-removal session is created at import time and reused for
# every request.
rembg_session = rembg.new_session()

# Settings for the conditioning image/camera handed to the model.
COND_WIDTH = 512
COND_HEIGHT = 512
COND_DISTANCE = 1.6
COND_FOVY_DEG = 40
# Colour blended in wherever the input's alpha channel is below 1.
BACKGROUND_COLOR = [0.5, 0.5, 0.5]

# Camera tensors are built once here and reused by every batch.
c2w_cond = sf3d_utils.default_cond_c2w(COND_DISTANCE)
intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
    COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
)

model = SF3D.from_pretrained(
    "stabilityai/stable-fast-3d",
    config_name="config.yaml",
    weight_name="model.safetensors",
)
model.eval().cuda()

# Sorted so the example gallery has a stable order; os.listdir returns
# entries in arbitrary order.
example_files = [
    os.path.join("demo_files/examples", f)
    for f in sorted(os.listdir("demo_files/examples"))
]


def run_model(input_image):
    """Generate a mesh from a preprocessed image and export it as a GLB file.

    Args:
        input_image: Image accepted by ``create_batch`` (it reads the last
            channel as the foreground mask).

    Returns:
        Path of a temporary ``.glb`` file holding the exported mesh. The file
        is created with ``delete=False`` and is not removed by this function.
    """
    start = time.time()
    with torch.no_grad():
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            model_batch = create_batch(input_image)
            model_batch = {k: v.cuda() for k, v in model_batch.items()}
            trimesh_mesh, _glob_dict = model.generate_mesh(model_batch, 1024)
            trimesh_mesh = trimesh_mesh[0]

    # Only the unique path is needed. Close the handle right away so it is
    # not leaked on every call and the exporter can reopen the path freely.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".glb") as tmp_file:
        glb_path = tmp_file.name

    trimesh_mesh.export(glb_path, file_type="glb", include_normals=True)

    print("Generation took:", time.time() - start, "s")

    return glb_path


def create_batch(input_image: Image.Image) -> dict[str, Any]:
    """Turn one image into a batch dictionary for ``model.generate_mesh``.

    The image is resized to ``COND_WIDTH`` x ``COND_HEIGHT`` and scaled to
    ``[0, 1]``. Its last channel is used as the mask and the first three
    channels are blended over ``BACKGROUND_COLOR`` using that mask.

    Args:
        input_image: Image whose last channel is the foreground mask
            (callers in this file pass RGBA images).

    Returns:
        Dict of tensors keyed ``rgb_cond``, ``mask_cond``, ``c2w_cond``,
        ``intrinsic_cond`` and ``intrinsic_normed_cond``, each with an extra
        leading dimension of size one.
    """
    resized = input_image.resize((COND_WIDTH, COND_HEIGHT))
    img_cond = (
        torch.from_numpy(np.asarray(resized).astype(np.float32) / 255.0)
        .float()
        .clip(0, 1)
    )
    mask_cond = img_cond[:, :, -1:]
    # lerp(background, rgb, mask): background where mask is 0, rgb where 1.
    rgb_cond = torch.lerp(
        torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond
    )

    batch_elem = {
        "rgb_cond": rgb_cond,
        "mask_cond": mask_cond,
        "c2w_cond": c2w_cond.unsqueeze(0),
        "intrinsic_cond": intrinsic.unsqueeze(0),
        "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
    }
    # Add the leading batch dimension to every entry.
    return {k: v.unsqueeze(0) for k, v in batch_elem.items()}
@lru_cache
def checkerboard(squares: int, size: int, min_value: float = 0.5):
    """Return a 3-channel checkerboard array, cached per argument tuple.

    Args:
        squares: Number of squares along each side.
        size: Requested side length in pixels. The actual side length is
            ``squares * (size // squares)``.
        min_value: Value of the darker squares; the lighter squares are 1.

    Returns:
        Float array of shape ``(side, side, 3)``. The array is cached and
        shared between callers, so it must not be modified in place.
    """
    base = np.zeros((squares, squares)) + min_value
    base[1::2, ::2] = 1
    base[::2, 1::2] = 1

    repeat_mult = size // squares

    return (
        base.repeat(repeat_mult, axis=0)
        .repeat(repeat_mult, axis=1)[:, :, None]
        .repeat(3, axis=-1)
    )


def remove_background(input_image: Image.Image) -> Image.Image:
    """Run ``rembg`` on the image using the shared module-level session."""
    return rembg.remove(input_image, session=rembg_session)


def _pad_centered(arr: np.ndarray, target: int) -> np.ndarray:
    """Zero-pad an H x W x C array to ``target`` x ``target``, centred."""
    height, width = arr.shape[:2]
    top = (target - height) // 2
    left = (target - width) // 2
    return np.pad(
        arr,
        ((top, target - height - top), (left, target - width - left), (0, 0)),
        mode="constant",
        constant_values=0,
    )


def resize_foreground(image: Image.Image, ratio: float) -> Image.Image:
    """Crop to the non-transparent region and re-pad it to a square canvas.

    The foreground (alpha > 0) bounding box is padded to a square, then padded
    again so that the square occupies ``ratio`` of the canvas side, and the
    result is resized to ``COND_WIDTH`` x ``COND_HEIGHT``.

    Args:
        image: RGBA image.
        ratio: Fraction of the canvas side the foreground square should fill.

    Returns:
        RGBA image of size ``(COND_WIDTH, COND_HEIGHT)``. A fully transparent
        input has no foreground to crop and is returned resized only.

    Raises:
        ValueError: If the image does not have four channels.
    """
    rgba = np.array(image)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if rgba.ndim != 3 or rgba.shape[-1] != 4:
        raise ValueError("resize_foreground expects an RGBA image")

    rows, cols = np.where(rgba[..., 3] > 0)
    if rows.size == 0:
        # No visible pixels: there is no bounding box to crop to.
        return Image.fromarray(rgba, mode="RGBA").resize((COND_WIDTH, COND_HEIGHT))

    # max() is the last foreground index; +1 makes the slice include it.
    y1, y2 = rows.min(), rows.max() + 1
    x1, x2 = cols.min(), cols.max() + 1
    fg = rgba[y1:y2, x1:x2]

    # Pad the crop to a square, then grow the canvas to honour `ratio`.
    size = max(fg.shape[0], fg.shape[1])
    squared = _pad_centered(fg, size)
    new_size = int(squared.shape[0] / ratio)
    canvas = _pad_centered(squared, new_size)

    return Image.fromarray(canvas, mode="RGBA").resize((COND_WIDTH, COND_HEIGHT))


def square_crop(input_image: Image.Image) -> Image.Image:
    """Centre-crop to a square and resize to the conditioning resolution."""
    width, height = input_image.size
    min_size = min(width, height)

    left = (width - min_size) // 2
    top = (height - min_size) // 2
    right = (width + min_size) // 2
    bottom = (height + min_size) // 2

    return input_image.crop((left, top, right, bottom)).resize(
        (COND_WIDTH, COND_HEIGHT)
    )


def show_mask_img(input_image: Image.Image) -> Image.Image:
    """Composite an RGBA image over a checkerboard to visualise its alpha.

    Returns:
        RGB image in which transparent areas show the checkerboard.
    """
    img_numpy = np.array(input_image)
    alpha = img_numpy[:, :, 3] / 255.0
    # `* 255` creates a new array, so the cached checkerboard is not mutated.
    chkb = checkerboard(32, 512) * 255
    new_img = img_numpy[..., :3] * alpha[:, :, None] + chkb * (1 - alpha[:, :, None])
    return Image.fromarray(new_img.astype(np.uint8), mode="RGB")


def run_button(run_btn, input_image, background_state, foreground_ratio):
    """Handle a click on the main button, whose label selects the action.

    Args:
        run_btn: Current button label, ``"Run"`` or ``"Remove Background"``.
        input_image: Uploaded image (used for background removal).
        background_state: Preprocessed image (used for mesh generation).
        foreground_ratio: Slider value passed to ``resize_foreground``.

    Returns:
        Six values, in order, for: button, cropped-image state,
        preprocessed-image state, preview image, 3D viewer, HDR column.

    Raises:
        ValueError: If the button label is not one of the two known values.
    """
    if run_btn == "Run":
        glb_file: str = run_model(background_state)

        return (
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(value=glb_file, visible=True),
            gr.update(visible=True),
        )

    if run_btn == "Remove Background":
        rem_removed = remove_background(input_image)

        sqr_crop = square_crop(rem_removed)
        fr_res = resize_foreground(sqr_crop, foreground_ratio)

        return (
            gr.update(value="Run", visible=True),
            sqr_crop,
            fr_res,
            gr.update(value=show_mask_img(fr_res), visible=True),
            gr.update(value=None, visible=False),
            gr.update(visible=False),
        )

    # Fail with a clear message instead of implicitly returning None.
    raise ValueError(f"Unexpected run button state: {run_btn!r}")


def requires_bg_remove(image, fr):
    """Decide, on image change, whether background removal is still needed.

    An image with at least one fully transparent pixel is treated as already
    cut out and is preprocessed immediately. Otherwise the button is switched
    to "Remove Background".

    Args:
        image: Uploaded RGBA image, or ``None`` when the input was cleared.
        fr: Foreground ratio passed to ``resize_foreground``.

    Returns:
        Six values, in order, for: button, cropped-image state,
        preprocessed-image state, preview image, 3D viewer, HDR column.
    """
    if image is None:
        return (
            gr.update(visible=False, value="Run"),
            None,
            None,
            gr.update(value=None, visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
        )

    alpha_channel = np.array(image.getchannel("A"))
    min_alpha = alpha_channel.min()

    if min_alpha == 0:
        sqr_crop = square_crop(image)
        fr_res = resize_foreground(sqr_crop, fr)

        return (
            gr.update(value="Run", visible=True),
            sqr_crop,
            fr_res,
            gr.update(value=show_mask_img(fr_res), visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
        )

    return (
        gr.update(value="Remove Background", visible=True),
        None,
        None,
        gr.update(value=None, visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
    )


def update_foreground_ratio(img_proc, fr):
    """Re-run foreground resizing when the ratio slider changes.

    Args:
        img_proc: Cropped image held in state, or ``None`` if no image has
            been preprocessed yet.
        fr: New foreground ratio.

    Returns:
        Two values: the new preprocessed-image state and a preview update.
    """
    if img_proc is None:
        # The slider is always visible, so it can be moved before any image
        # has been preprocessed; there is nothing to recompute then.
        return None, gr.update()

    foreground_res = resize_foreground(img_proc, fr)
    return (
        foreground_res,
        gr.update(value=show_mask_img(foreground_res)),
    )


class CustomTheme(gr.themes.Base):
    """Gradio theme that sets a dark background, white text and Poppins font."""

    def __init__(self):
        super().__init__()
        self.primary_hue = "#191a1e"
        self.background_fill_primary = "#191a1e"
        self.background_fill_secondary = "#191a1e"
        self.background_fill_tertiary = "#191a1e"
        self.text_color_primary = "#FFFFFF"
        self.text_color_secondary = "#FFFFFF"
        self.text_color_tertiary = "#FFFFFF"
        self.input_background_fill = "#191a1e"
        self.input_text_color = "#FFFFFF"
        self.font = (
            "Poppins",
            "https://fonts.googleapis.com/css2?family=Poppins:wght@400;500;700&display=swap",
        )


# Page-level CSS overrides passed to gr.Blocks below.
css = """
:root, html, body, #root, #__next,
body > div, body > div > div, body > div > div > div {
    background-color: #191a1e !important;
    margin: 0 !important;
    padding: 0 !important;
    border: none !important;
    width: 100% !important;
    min-height: 100%;
    height: 100%;
    box-sizing: border-box;
}

body {
    overflow-x: hidden !important;
}

.gradio-container, .gr-app {
    background-color: #191a1e !important;
    width: 100% !important;
    margin: 0 !important;
    padding: 0 !important;
    box-sizing: border-box;
    display: flex;
    flex-direction: column;
}

/* Снятие ограничений по ширине с внутренних контейнеров */
.gr-block, .gr-box {
    max-width: 100% !important;
    width: 100% !important;
}

/* Если какие-то элементы имели свой фон, делаем прозрачным */
* {
    background-color: transparent !important;
}

/* Ряды и колонки */
.gr-row {
    display: flex;
    flex-wrap: wrap;
    gap: 20px;
    align-items: flex-start;
    width: 100%;
    box-sizing: border-box;
}

.gr-column {
    flex: 1 1 auto;
    min-width: 300px;
    display: flex;
    flex-direction: column;
    gap: 10px;
    width: 100%;
    box-sizing: border-box;
}

/* Уменьшаем размеры изображений */
.gr-image img {
    max-width: 300px !important;
    max-height: 300px !important;
    object-fit: contain;
}

/* Кнопка */
.generate-button {
    background-color: #5271FF !important;
    color: #FFFFFF !important;
    border: none !important;
    font-weight: bold !important;
    font-size: 1.2em !important;
    padding: 0.75em 2em !important;
    border-radius: 0.3em !important;
}

/* Примеры сразу под кнопкой Run */
.gr-examples {
    display: flex !important;
    flex-wrap: wrap !important;
    gap: 10px !important;
    justify-content: flex-start !important;
    align-items: center !important;
}

.gr-examples img {
    max-height: 64px;
    object-fit: contain;
}

footer {
    display: none !important;
}
"""

with gr.Blocks(theme=CustomTheme(), css=css) as demo:
    # img_proc_state holds the square-cropped image; background_remove_state
    # holds the foreground-resized image that is fed to the model.
    img_proc_state = gr.State()
    background_remove_state = gr.State()

    with gr.Row():
        # Left column: input, preview and controls.
        with gr.Column():
            with gr.Row():
                input_img = gr.Image(
                    type="pil",
                    label="Input Image",
                    sources="upload",
                    image_mode="RGBA",
                )
                preview_removal = gr.Image(
                    label="Preview Background Removal",
                    type="pil",
                    image_mode="RGB",
                    interactive=False,
                    visible=False,
                )

            foreground_ratio = gr.Slider(
                label="Foreground Ratio",
                minimum=0.5,
                maximum=1.0,
                value=0.85,
                step=0.05,
            )
            foreground_ratio.change(
                update_foreground_ratio,
                inputs=[img_proc_state, foreground_ratio],
                outputs=[background_remove_state, preview_removal],
            )

            run_btn = gr.Button(
                "Run",
                variant="primary",
                visible=False,
                elem_classes="generate-button",
            )

            # NOTE(review): placed inside this column, right after the button,
            # to match the statement order and the CSS comment about examples
            # sitting under the Run button — confirm the intended nesting.
            examples = gr.Examples(
                examples=example_files,
                inputs=input_img,
            )

        # Right column: 3D viewer and HDR lighting controls.
        with gr.Column():
            output_3d = LitModel3D(
                label="3D Model",
                visible=False,
                clear_color=[0.0, 0.0, 0.0, 0.0],
                tonemapping="aces",
                contrast=1.0,
                scale=1.0,
            )
            with gr.Column(visible=False, scale=1.0) as hdr_row:
                gr.Markdown(
                    """
                    ## HDR Environment Map

                    Select an HDR environment map to light the 3D model. You can also upload your own HDR environment maps.
                    """
                )

                with gr.Row():
                    hdr_illumination_file = gr.File(
                        label="HDR Env Map", file_types=[".hdr"], file_count="single"
                    )
                    example_hdris = [
                        os.path.join("demo_files/hdri", f)
                        for f in os.listdir("demo_files/hdri")
                    ]
                    hdr_illumination_example = gr.Examples(
                        examples=example_hdris,
                        inputs=hdr_illumination_file,
                    )

                    # Point the viewer at the chosen env map, or clear it when
                    # the file input is emptied.
                    hdr_illumination_file.change(
                        lambda x: gr.update(env_map=x.name if x is not None else None),
                        inputs=hdr_illumination_file,
                        outputs=[output_3d],
                    )

    input_img.change(
        requires_bg_remove,
        inputs=[input_img, foreground_ratio],
        outputs=[
            run_btn,
            img_proc_state,
            background_remove_state,
            preview_removal,
            output_3d,
            hdr_row,
        ],
    )

    run_btn.click(
        run_button,
        inputs=[
            run_btn,
            input_img,
            background_remove_state,
            foreground_ratio,
        ],
        outputs=[
            run_btn,
            img_proc_state,
            background_remove_state,
            preview_removal,
            output_3d,
            hdr_row,
        ],
    )

demo.launch()