# Spaces status at capture time: Runtime error
import datetime
import os
import tempfile
import time
from functools import lru_cache
from typing import Any

import boto3
import gradio as gr
import numpy as np
import rembg
import torch
from botocore.exceptions import NoCredentialsError, PartialCredentialsError
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import FileResponse
from gradio_litmodel3d import LitModel3D
from PIL import Image

import sf3d.utils as sf3d_utils
from sf3d.system import SF3D
# --- Service clients ---------------------------------------------------------
# Credentials come from the ACCESS / SECRET environment variables; either is
# None when unset. NOTE(review): confirm the deployment provides both.
ACCESS = os.getenv("ACCESS")
SECRET = os.getenv("SECRET")
_AWS_CLIENT_KWARGS = {
    "aws_access_key_id": ACCESS,
    "aws_secret_access_key": SECRET,
    "region_name": "us-east-1",
}
bedrock = boto3.client("bedrock", **_AWS_CLIENT_KWARGS)
bedrock_runtime = boto3.client("bedrock-runtime", **_AWS_CLIENT_KWARGS)
s3_client = boto3.client("s3", **_AWS_CLIENT_KWARGS)

app = FastAPI()
# Created once at import time and reused by remove_background().
rembg_session = rembg.new_session()

# --- Conditioning-view setup ---------------------------------------------------
COND_WIDTH = 512
COND_HEIGHT = 512
COND_DISTANCE = 1.6  # passed to sf3d_utils.default_cond_c2w (presumably camera distance)
COND_FOVY_DEG = 40  # field of view in degrees, per create_intrinsic_from_fov_deg
BACKGROUND_COLOR = [0.5, 0.5, 0.5]  # RGB in [0, 1] blended in where alpha is 0
# Camera pose and intrinsics are the same for every request: compute them once.
c2w_cond = sf3d_utils.default_cond_c2w(COND_DISTANCE)
intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
    COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
)

# --- Model -----------------------------------------------------------------------
# Loaded once at import time; run_model() moves its inputs to CUDA, so the
# weights are placed there too.
model = SF3D.from_pretrained(
    "stabilityai/stable-fast-3d",
    config_name="config.yaml",
    weight_name="model.safetensors",
)
model.eval().cuda()

# NOTE(review): not referenced by the code visible here (looks like a leftover
# from a Gradio demo); os.listdir raises at import time if the directory is
# missing.
example_files = [
    os.path.join("demo_files/examples", entry)
    for entry in os.listdir("demo_files/examples")
]
def run_model(input_image):
    """Generate a mesh from a prepared conditioning image and save it as GLB.

    Args:
        input_image: RGBA ``PIL.Image`` (as produced by ``resize_foreground``);
            it is converted to model inputs by ``create_batch``.

    Returns:
        Path of a new ``.glb`` file in the system temp directory. The file is
        created with ``delete=False``, so the caller must remove it.

    Raises:
        Whatever ``create_batch``, ``model.generate_mesh`` or the mesh export
        raise. The temp file is removed before the exception propagates.
    """
    import contextlib

    start = time.perf_counter()
    # Reserve a unique path and close our handle straight away: trimesh writes
    # the file by name, and an open NamedTemporaryFile cannot be reopened by
    # name on every platform.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".glb") as tmp_file:
        glb_path = tmp_file.name
    try:
        with torch.no_grad():
            # fp16 autocast on CUDA; the batch tensors are moved to the GPU below.
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                model_batch = create_batch(input_image)
                model_batch = {k: v.cuda() for k, v in model_batch.items()}
                # NOTE(review): 1024 is presumably the texture/bake resolution
                # expected by SF3D.generate_mesh; confirm against the sf3d API.
                trimesh_mesh, _glob_dict = model.generate_mesh(model_batch, 1024)
                trimesh_mesh = trimesh_mesh[0]
            trimesh_mesh.export(glb_path, file_type="glb", include_normals=True)
    except Exception:
        # Don't leave an empty or partial .glb behind when generation fails.
        with contextlib.suppress(OSError):
            os.remove(glb_path)
        raise
    print("Generation took:", time.perf_counter() - start, "s")
    return glb_path
def create_batch(input_image: Image) -> dict[str, Any]:
    """Turn an RGBA conditioning image into the single-item batch for SF3D.

    The image is resized to ``(COND_WIDTH, COND_HEIGHT)`` and scaled to [0, 1].
    Its last channel is used as the mask. The RGB channels are composited over
    ``BACKGROUND_COLOR`` with that mask as the blend weight. The shared camera
    tensors are then attached.

    Returns:
        Dict of tensors keyed ``rgb_cond``, ``mask_cond``, ``c2w_cond``,
        ``intrinsic_cond`` and ``intrinsic_normed_cond``. Each tensor has an
        extra leading batch dimension.
    """
    resized = input_image.resize((COND_WIDTH, COND_HEIGHT))
    pixels = np.asarray(resized).astype(np.float32) / 255.0
    img_cond = torch.from_numpy(pixels).float().clip(0, 1)

    mask_cond = img_cond[:, :, -1:]
    background = torch.tensor(BACKGROUND_COLOR)[None, None, :]
    # lerp(background, rgb, mask): mask 0 -> background, mask 1 -> image RGB.
    rgb_cond = torch.lerp(background, img_cond[:, :, :3], mask_cond)

    per_sample = {
        "rgb_cond": rgb_cond,
        "mask_cond": mask_cond,
        "c2w_cond": c2w_cond.unsqueeze(0),
        "intrinsic_cond": intrinsic.unsqueeze(0),
        "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
    }
    return {name: tensor.unsqueeze(0) for name, tensor in per_sample.items()}
def checkerboard(squares: int, size: int, min_value: float = 0.5):
    """Return a ``squares`` x ``squares`` checkerboard as an RGB float array.

    Each square is ``size // squares`` pixels wide. Square ``(i, j)`` is 1 when
    ``i + j`` is odd and ``min_value`` otherwise, so the top-left square is
    ``min_value``. The result has shape ``(n, n, 3)`` with
    ``n = squares * (size // squares)``. This can be smaller than ``size`` when
    ``size`` is not a multiple of ``squares``.
    """
    cell_px = size // squares
    parity = np.add.outer(np.arange(squares), np.arange(squares)) % 2
    tile = np.zeros((squares, squares)) + min_value
    tile[parity == 1] = 1
    # Blow every square up to cell_px x cell_px pixels, then copy into 3 channels.
    board = np.kron(tile, np.ones((cell_px, cell_px)))
    return np.repeat(board[:, :, np.newaxis], 3, axis=-1)
def remove_background(input_image: Image) -> Image:
    """Cut the subject out of ``input_image`` with rembg.

    Uses the module-level ``rembg_session`` instead of creating a new session
    on each call.
    """
    cutout = rembg.remove(input_image, session=rembg_session)
    return cutout
def resize_foreground(
    image: Image,
    ratio: float,
) -> Image:
    """Re-frame an RGBA image so its foreground fills ``ratio`` of a square canvas.

    The foreground is the bounding box of all pixels with non-zero alpha. It is
    cropped and padded (transparent, centred) to a square. It is then padded
    again so its side is ``ratio`` of the canvas side. Finally the result is
    resized to ``(COND_WIDTH, COND_HEIGHT)``.

    Args:
        image: RGBA image (anything ``np.array`` turns into an ``(H, W, 4)``
            array).
        ratio: Fraction of the final canvas side taken by the foreground's
            longer side; must be in (0, 1].

    Returns:
        RGBA ``PIL.Image`` of size ``(COND_WIDTH, COND_HEIGHT)``.

    Raises:
        ValueError: if the image is not 4-channel, ``ratio`` is outside (0, 1],
            or no pixel has non-zero alpha.
    """
    image = np.array(image)
    if image.ndim != 3 or image.shape[-1] != 4:
        raise ValueError(
            f"expected an RGBA image, got an array of shape {image.shape}"
        )
    if not 0 < ratio <= 1:
        raise ValueError(f"ratio must be in (0, 1], got {ratio}")
    rows, cols = np.where(image[..., 3] > 0)
    if rows.size == 0:
        raise ValueError("image has no foreground: every alpha value is 0")
    # max() is the index of the last foreground row/column; add 1 so the
    # exclusive slice end keeps it.
    y1, y2 = rows.min(), rows.max() + 1
    x1, x2 = cols.min(), cols.max() + 1
    # crop the foreground
    fg = image[y1:y2, x1:x2]
    # pad to square, centring the crop (an odd leftover goes to bottom/right)
    size = max(fg.shape[0], fg.shape[1])
    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
    new_image = np.pad(
        fg,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )
    # grow the canvas so the square foreground spans `ratio` of its side;
    # ratio <= 1 guarantees new_size >= size, i.e. non-negative padding
    new_size = int(new_image.shape[0] / ratio)
    pad0 = (new_size - size) // 2
    pad1 = new_size - size - pad0
    new_image = np.pad(
        new_image,
        ((pad0, pad1), (pad0, pad1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )
    new_image = Image.fromarray(new_image, mode="RGBA").resize(
        (COND_WIDTH, COND_HEIGHT)
    )
    return new_image
def square_crop(input_image: Image) -> Image:
    """Centre-crop ``input_image`` to a square and resize it to the conditioning size.

    The square's side is the shorter image dimension. The crop box is centred,
    and integer division puts any odd leftover pixel on the right/bottom side.
    """
    width, height = input_image.size
    side = min(width, height)
    box = (
        (width - side) // 2,
        (height - side) // 2,
        (width + side) // 2,
        (height + side) // 2,
    )
    cropped = input_image.crop(box)
    return cropped.resize((COND_WIDTH, COND_HEIGHT))
def show_mask_img(input_image: Image) -> Image:
    """Composite an image over a grey/white checkerboard to visualise its alpha.

    Fully transparent pixels show the checkerboard and opaque pixels show the
    image's RGB. Partially transparent pixels are alpha-blended between the two.

    Args:
        input_image: Image of any size. It is converted to RGBA first, so
            inputs without an alpha channel are treated as fully opaque.

    Returns:
        RGB ``PIL.Image`` with the same size as ``input_image``.
    """
    img_numpy = np.array(input_image.convert("RGBA"))
    height, width = img_numpy.shape[:2]
    alpha = img_numpy[:, :, 3] / 255.0
    # 16 px squares, as in a 32-square board over 512 px. Build a board that
    # covers the image and crop it to the image's exact shape rather than
    # assuming a 512x512 input; for 512x512 this is checkerboard(32, 512).
    cell_px = 16
    squares = max(1, -(-max(height, width) // cell_px))  # ceil division
    chkb = checkerboard(squares, squares * cell_px)[:height, :width] * 255
    new_img = img_numpy[..., :3] * alpha[:, :, None] + chkb * (1 - alpha[:, :, None])
    return Image.fromarray(new_img.astype(np.uint8), mode="RGB")
def upload_file_to_s3(file_path, bucket_name, object_name=None):
    """Upload a local file to S3 using the module-level ``s3_client``.

    Args:
        file_path: Path of the local file to upload.
        bucket_name: Name of the destination bucket.
        object_name: Destination key. When omitted, the file's base name is
            used, because ``upload_file`` requires a key.

    Returns:
        True once the upload call has returned.

    Raises:
        Any exception raised by ``s3_client.upload_file`` (for example a
        botocore credentials error) propagates unchanged.
    """
    if object_name is None:
        object_name = os.path.basename(file_path)
    s3_client.upload_file(file_path, bucket_name, object_name)
    return True
# NOTE(review): route path chosen here so the handler is reachable over HTTP;
# align it with the path clients actually call.
@app.post("/process_image")
async def process_image(file: UploadFile = File(...), foreground_ratio: float = 0.85):
    """Turn an uploaded image into a GLB mesh and publish it to S3.

    Pipeline:
        1. Decode the upload as RGBA.
        2. Remove the background with rembg.
        3. Centre-crop to a square.
        4. Re-frame the foreground to ``foreground_ratio``.
        5. Run SF3D.
        6. Upload the ``.glb`` to the ``framebucket3d`` bucket under a
           timestamped key.

    The local GLB is deleted afterwards.

    Args:
        file: Uploaded image (multipart form field ``file``).
        foreground_ratio: Fraction of the frame the foreground's longer side
            should occupy; must be in (0, 1].

    Returns:
        ``{"glb_path": <S3 URL of the uploaded mesh>}``.

    Raises:
        HTTPException:
            400 if the upload is not a readable image.
            422 if ``foreground_ratio`` is out of range or no foreground is
            found.
            500 if the S3 upload reports failure.
    """
    import contextlib

    from fastapi import HTTPException

    if not 0 < foreground_ratio <= 1:
        raise HTTPException(
            status_code=422, detail="foreground_ratio must be in (0, 1]"
        )

    # NOTE(review): the work below (rembg, GPU inference, S3 upload) is
    # synchronous inside an ``async def``, so it blocks the event loop for the
    # whole request, effectively serialising requests within one worker.
    try:
        input_image = Image.open(file.file).convert("RGBA")
    except OSError as err:
        # Pillow raises UnidentifiedImageError (an OSError subclass) for
        # unrecognised data and OSError for truncated files.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image"
        ) from err
    rem_removed = remove_background(input_image)
    sqr_crop = square_crop(rem_removed)
    try:
        fr_res = resize_foreground(sqr_crop, foreground_ratio)
    except ValueError as err:
        # e.g. background removal left no pixel with non-zero alpha.
        raise HTTPException(
            status_code=422, detail="No foreground object found in the image"
        ) from err
    glb_file = run_model(fr_res)

    bucket_name = "framebucket3d"
    timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')
    object_name = f'object_{timestamp}.glb'
    try:
        uploaded = upload_file_to_s3(glb_file, bucket_name, object_name)
    finally:
        # run_model() creates the GLB with delete=False; it is only needed
        # locally until the upload has finished (or failed).
        with contextlib.suppress(OSError):
            os.remove(glb_file)
    if not uploaded:
        raise HTTPException(status_code=500, detail="Failed to upload mesh to S3")
    return {
        "glb_path": f"https://{bucket_name}.s3.amazonaws.com/{object_name}"
    }
if __name__ == "__main__":
    # Imported here: uvicorn is only needed when the module runs as a script.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)