# smallville / TripoSR / obj_gen.py
# Last change: "Update TripoSR/obj_gen.py" by CazC (commit 3aa18cd, verified).
import logging
import os
import tempfile
import time
import numpy as np
import rembg
import torch
from PIL import Image
from functools import partial
from tsr.system import TSR
from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
import argparse
# Run on the first CUDA GPU when one is visible to torch, otherwise on CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the pretrained TripoSR weights once at import time; every call to
# generate() below reuses this single model instance.
model = TSR.from_pretrained(
    "stabilityai/TripoSR",
    config_name="config.yaml",
    weight_name="model.ckpt",
)
# Renderer chunk size: larger is faster but needs more memory.
model.renderer.set_chunk_size(8192)
model.to(device)

# Shared rembg session, created once and reused by preprocess().
rembg_session = rembg.new_session()
def preprocess(input_image, do_remove_background, foreground_ratio):
    """Prepare a PIL image for TripoSR inference.

    Args:
        input_image: Source ``PIL.Image.Image``.
        do_remove_background: When true, the image is converted to RGB, its
            background is removed with ``remove_background`` (using the
            module-level ``rembg_session``), the foreground is rescaled with
            ``resize_foreground`` and the result is composited onto gray.
            When false, the image is only composited if it carries
            transparency, and is otherwise normalized to RGB.
        foreground_ratio: Forwarded to ``resize_foreground``; only used when
            ``do_remove_background`` is true.

    Returns:
        A ``PIL.Image.Image``. Composited images are RGB over 50% gray.
    """

    def fill_background(image):
        # Alpha-composite onto a uniform 0.5 gray: rgb * a + 0.5 * (1 - a).
        # Indexes channels 0-2 as colour and channel 3 as alpha, so the
        # input must be a 4-channel (RGBA) image.
        rgba = np.array(image).astype(np.float32) / 255.0
        rgb, alpha = rgba[:, :, :3], rgba[:, :, 3:4]
        blended = rgb * alpha + (1 - alpha) * 0.5
        return Image.fromarray((blended * 255.0).astype(np.uint8))

    if do_remove_background:
        image = input_image.convert("RGB")
        # NOTE(review): assumes remove_background / resize_foreground return
        # an RGBA image, as fill_background requires - confirm in tsr.utils.
        image = remove_background(image, rembg_session)
        image = resize_foreground(image, foreground_ratio)
        return fill_background(image)

    image = input_image
    # Other alpha-bearing modes (LA, PA, palette/RGB with a transparency key)
    # would otherwise skip compositing; bring them to RGBA first.
    if image.mode != "RGBA" and (
        "A" in image.getbands() or "transparency" in image.info
    ):
        image = image.convert("RGBA")
    if image.mode == "RGBA":
        return fill_background(image)
    # Keep the output channel layout consistent with the paths above (RGB).
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image
def _mesh_path_for(path, fmt):
    """Return ``path`` with its file extension replaced by ``.{fmt}``.

    Only the final extension is swapped (an extension-less path gets one
    appended), so directory names that happen to contain ".obj" are left
    untouched.
    """
    root, _ext = os.path.splitext(path)
    return f"{root}.{fmt}"


def generate(image, mc_resolution, formats=("obj", "glb"), path="output.obj"):
    """Reconstruct a mesh from a preprocessed image and export it to disk.

    Args:
        image: Preprocessed image passed directly to the TripoSR ``model``
            (for example, the output of ``preprocess``).
        mc_resolution: Resolution forwarded to ``model.extract_mesh``.
        formats: File extensions (without the dot) to export, in order.
        path: Base output path; its extension is replaced by each entry of
            ``formats``.

    Returns:
        List of written file paths, one per entry in ``formats``.
    """
    # Inference only: disable autograd so no graph/activations are retained.
    with torch.no_grad():
        scene_codes = model(image, device=device)
        # extract_mesh returns a sequence; take the first mesh.
        mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
    mesh = to_gradio_3d_orientation(mesh)

    exported = []
    for fmt in formats:
        mesh_path = _mesh_path_for(path, fmt)
        mesh.export(mesh_path)
        exported.append(mesh_path)
    return exported
def run_example(image_pil):
    """Run the demo pipeline, keeping the background and exporting .obj and .glb.

    The meshes are written via ``generate``'s default output path.

    Returns:
        A ``(preprocessed_image, obj_path, glb_path)`` tuple.
    """
    prepared = preprocess(image_pil, do_remove_background=False, foreground_ratio=0.9)
    obj_path, glb_path = generate(prepared, mc_resolution=256, formats=["obj", "glb"])
    return prepared, obj_path, glb_path
def generate_obj_from_image(image_pil, path="output.obj"):
    """Convert an image into an .obj mesh file on a best-effort basis.

    Args:
        image_pil: Input ``PIL.Image.Image``.
        path: Output path handed to ``generate``.

    Returns:
        Path of the written .obj file, or ``None`` if any step raised. Errors
        are logged with their traceback instead of being propagated.
    """
    try:
        # Remove the background and rescale the foreground (ratio 0.9).
        print("Preprocessing image")
        preprocessed = preprocess(image_pil, True, 0.9)
        print("Generating mesh")
        # Only the "obj" format is requested, so exactly one path comes back.
        mesh_paths = generate(preprocessed, 256, ["obj"], path)
    except Exception as e:
        # Deliberate best-effort boundary: callers receive None on failure,
        # but keep the traceback so the cause is not lost.
        logging.getLogger(__name__).exception("Error generating mesh: %s", e)
        return None
    return mesh_paths[0]
if __name__ == "__main__":
    # Small CLI; with no arguments it reproduces the original smoke test
    # (read output.png, write output.obj in the working directory).
    parser = argparse.ArgumentParser(
        description="Generate an .obj mesh from an image with TripoSR."
    )
    parser.add_argument(
        "image_path",
        nargs="?",
        default="output.png",
        help="input image (default: output.png)",
    )
    parser.add_argument(
        "--output",
        default="output.obj",
        help="output mesh path (default: output.obj)",
    )
    args = parser.parse_args()

    # Close the image file handle once the mesh has been generated.
    with Image.open(args.image_path) as image:
        mesh_path = generate_obj_from_image(image, args.output)

    # generate_obj_from_image returns None on failure; surface that as a
    # non-zero exit status instead of finishing silently.
    if mesh_path is None:
        raise SystemExit(1)
    print(f"Saved mesh to {mesh_path}")
    # TODO: move the .obj file to the output directory (not implemented).