# tripo-custom / main.py
# Author: ashh757
# Last commit: "Update main.py"
# Commit c88bb88 (verified)
import os
from utils import process_image, run_model
from typing import Tuple, Optional
from PIL import Image
from boto3 import Session
import torch
import pickle
import datetime
import gzip
# Module-level AWS setup: one boto3 session and one S3 client are created at
# import time and shared by every upload in this file.
# Credentials and region are read from the process environment; a variable
# that is not set is passed through as None.
session = Session(
    aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"),
    aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"),
    region_name=os.environ.get("AWS_DEFAULT_REGION"),
)
s3 = session.client("s3")
def load_model():
    """Load and return the model stored in ``model_quantized_compressed.pkl.gz``.

    The archive is looked up relative to the current working directory,
    gzip-decompressed and unpickled.

    Returns:
        The object that was pickled into the archive.

    Raises:
        FileNotFoundError: If the archive does not exist.
        OSError: If the file is not a valid gzip stream.
        pickle.UnpicklingError: If the payload is not a valid pickle.
    """
    # SECURITY: unpickling can execute arbitrary code. Only load an archive
    # that comes from a trusted source.
    with gzip.open('model_quantized_compressed.pkl.gz', 'rb') as f_in:
        # Unpickle straight from the gzip stream so the fully decompressed
        # payload is not held in memory next to the model built from it.
        model = pickle.load(f_in)
        # Read to end-of-stream so the gzip trailer (CRC/length) is still
        # validated, as it was when the whole file was read up front.
        f_in.read()
    print("Model Loaded")
    return model
def upload_to_s3(file_path, bucket_name, s3_key):
    """Upload a local file to S3 and return its ``s3://`` URL.

    Args:
        file_path: Path of the local file to upload; opened in binary mode.
        bucket_name: Name of the destination bucket.
        s3_key: Object key to store the file under.

    Returns:
        The string ``s3://<bucket_name>/<s3_key>``.
    """
    # Uses the module-level ``s3`` client; the file handle is closed as soon
    # as the transfer call returns.
    with open(file_path, mode='rb') as stream:
        s3.upload_fileobj(stream, bucket_name, s3_key)
    return f's3://{bucket_name}/{s3_key}'
def _build_s3_key(folder: Optional[str], object_id: str) -> str:
    """Return ``folder/object_id``, or just ``object_id`` when no folder is given."""
    if folder is None:
        return object_id
    return folder + '/' + object_id


def generate_mesh(image_path: str,
                  output_dir: str = 'tmp/output/',
                  no_remove_bg: bool = True,
                  foreground_ratio: float = 0.85,
                  render: bool = False,
                  mc_resolution: int = 256,
                  bake_texture_flag: bool = False,
                  texture_resolution: int = 2048,
                  model=None,
                  bucket_name: Optional[str] = None,
                  input_folder: Optional[str] = None,
                  output_folder: Optional[str] = None,
                  input_s3_id: str = 'input_image.png',
                  output_s3_id: str = 'output_mesh.obj',
                  output_video_s3_id: Optional[str] = None
                  ) -> Tuple[Optional[str], Optional[str]]:
    """Generate a mesh from an image and upload the input and results to S3.

    The image is pre-processed with ``process_image``, the result is fed to
    ``run_model`` (on ``cuda:0`` when CUDA is available, otherwise on the
    CPU, always with ``model_save_format='obj'``), and the input image, the
    mesh file and, when one was produced, the video are uploaded to
    ``bucket_name``.

    Args:
        image_path: Path of the input image; passed to ``process_image`` and
            uploaded as the input object.
        output_dir: Forwarded to both ``process_image`` and ``run_model``.
        no_remove_bg: Forwarded to ``process_image``.
        foreground_ratio: Forwarded to ``process_image``.
        render: Forwarded to ``run_model``.
        mc_resolution: Forwarded to ``run_model``.
        bake_texture_flag: Forwarded to ``run_model``.
        texture_resolution: Forwarded to ``run_model``.
        model: Model object forwarded to ``run_model``.
        bucket_name: Destination bucket. The uploads are always attempted,
            so a real bucket name is needed in practice.
        input_folder: Optional key prefix for the input image.
        output_folder: Optional key prefix for the mesh and the video.
        input_s3_id: Object name of the uploaded input image.
        output_s3_id: Object name of the uploaded mesh.
        output_video_s3_id: Object name of the uploaded video. When ``None``
            and a video was produced, the video's own file name is used.

    Returns:
        ``(output_file_path, output_video_path)`` exactly as returned by
        ``run_model``.
    """
    print('Process start')
    image = process_image(image_path=image_path,
                          output_dir=output_dir,
                          no_remove_bg=no_remove_bg,
                          foreground_ratio=foreground_ratio)
    print('Process end')

    print('Run start')
    output_file_path, output_video_path = run_model(
        model=model,
        image=image,
        output_dir=output_dir,
        device="cuda:0" if torch.cuda.is_available() else "cpu",
        render=render,
        mc_resolution=mc_resolution,
        model_save_format='obj',
        bake_texture_flag=bake_texture_flag,
        texture_resolution=texture_resolution)
    print('Run end')

    print('Uploading to bucket...')
    # Upload the input image and generated mesh file to S3
    input_s3_key = _build_s3_key(input_folder, input_s3_id)
    output_s3_key = _build_s3_key(output_folder, output_s3_id)
    input_s3_url = upload_to_s3(image_path, bucket_name, input_s3_key)
    output_s3_url = upload_to_s3(output_file_path, bucket_name, output_s3_key)

    # Bind the video URL up front: when no video was produced the summary
    # below must still be printable instead of hitting an unbound local.
    output_video_s3_url = None
    if output_video_path is not None:
        video_id = output_video_s3_id
        if video_id is None:
            # No explicit object name was given; fall back to the video's
            # file name rather than building a key from None.
            video_id = os.path.basename(output_video_path)
        output_video_s3_key = _build_s3_key(output_folder, video_id)
        output_video_s3_url = upload_to_s3(output_video_path, bucket_name, output_video_s3_key)

    print(f'Files uploaded to S3:\nInput Image: {input_s3_url}\nOutput Mesh: {output_s3_url}\nOutput Video: {output_video_s3_url}')
    return output_file_path, output_video_path