Spaces:
Running
on
Zero
Running
on
Zero
import gc
import os
import tempfile

import numpy as np
import torch
import trimesh
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.diffusion.sample import sample_latents
from shap_e.models.download import load_config, load_model
from shap_e.models.nn.camera import (DifferentiableCameraBatch,
                                     DifferentiableProjectiveCamera)
from shap_e.models.transmitter.base import Transmitter, VectorDecoder
from shap_e.rendering.torch_mesh import TorchMesh
from shap_e.util.collections import AttrDict
from shap_e.util.image_util import load_image
# Copied from https://github.com/openai/shap-e/blob/d99cedaea18e0989e340163dbaeb4b109fa9e8ec/shap_e/util/notebooks.py#L15-L42 | |
def create_pan_cameras(size: int,
                       device: torch.device) -> DifferentiableCameraBatch:
    """Build a ring of 20 cameras orbiting the world origin.

    The orbit angles are ``np.linspace(0, 2 * pi, num=20)`` with the endpoint
    included, so the first and last poses coincide. Each camera sits 4 units
    from the world origin, and its ``z`` axis points straight at it
    (``origin == -4 * z``).

    Args:
        size: Width and height, in pixels, of every rendered view.
        device: Device that the camera tensors are moved to.

    Returns:
        A ``DifferentiableCameraBatch`` of shape ``(1, 20)`` with
        ``x_fov == y_fov == 0.7``.
    """

    def pose_at(theta):
        # Unit viewing direction. The -0.5 third component tilts the view
        # out of the x-y plane.
        view_dir = np.array([np.sin(theta), np.cos(theta), -0.5])
        view_dir /= np.sqrt(np.sum(view_dir**2))
        right = np.array([np.cos(theta), -np.sin(theta), 0.0])
        # (origin, x axis, y axis, z axis) for this camera.
        return -view_dir * 4, right, np.cross(view_dir, right), view_dir

    poses = [pose_at(theta) for theta in np.linspace(0, 2 * np.pi, num=20)]
    origins, xs, ys, zs = zip(*poses)

    def stacked(vectors):
        # (num_cameras, 3) float32 tensor on the requested device.
        return torch.from_numpy(np.stack(vectors, axis=0)).float().to(device)

    return DifferentiableCameraBatch(
        shape=(1, len(poses)),
        flat_camera=DifferentiableProjectiveCamera(
            origin=stacked(origins),
            x=stacked(xs),
            y=stacked(ys),
            z=stacked(zs),
            width=size,
            height=size,
            x_fov=0.7,
            y_fov=0.7,
        ),
    )
# Copied from https://github.com/openai/shap-e/blob/8625e7c15526d8510a2292f92165979268d0e945/shap_e/util/notebooks.py#LL64C1-L76C33 | |
def decode_latent_mesh(
    xm: Transmitter | VectorDecoder,
    latent: torch.Tensor,
) -> TorchMesh:
    """Decode one Shap-E latent into a mesh using the STF rendering mode.

    Args:
        xm: The latent decoder. For a ``Transmitter``, its ``encoder``
            provides ``bottleneck_to_params``; otherwise ``xm`` itself does.
        latent: A single latent without a batch axis. One is added here via
            ``latent[None]``.

    Returns:
        The first entry of ``raw_meshes`` produced by the renderer.
    """
    # Only the mesh is kept, so the views are rendered at the smallest size
    # (2x2 pixels) to keep this cheap ("lowest resolution possible").
    view_request = AttrDict(cameras=create_pan_cameras(2, latent.device))

    param_source = xm.encoder if isinstance(xm, Transmitter) else xm
    params = param_source.bottleneck_to_params(latent[None])

    render_options = AttrDict(rendering_mode='stf',
                              render_with_direction=False)
    rendered = xm.renderer.render_views(view_request,
                                        params=params,
                                        options=render_options)
    return rendered.raw_meshes[0]
class Model:
    """Shap-E wrapper that turns a text prompt or an image into a GLB file.

    The transmitter (latent decoder) and the diffusion schedule are loaded
    once in ``__init__``. The conditional latent model ('text300M' or
    'image300M') is loaded lazily and swapped on demand. The previous model
    is released before the next one is loaded, so only one is resident at
    a time.
    """

    # Conditional latent models this wrapper knows how to drive.
    _MODEL_NAMES = ('text300M', 'image300M')

    def __init__(self):
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        # Decoder used by to_glb to turn latents into meshes.
        self.xm = load_model('transmitter', device=self.device)
        self.diffusion = diffusion_from_config(load_config('diffusion'))
        # Name of the currently loaded conditional model; '' means none.
        self.model_name = ''
        self.model = None

    @staticmethod
    def _release_memory() -> None:
        """Collect Python garbage and return cached CUDA blocks to the driver."""
        gc.collect()
        torch.cuda.empty_cache()

    def load_model(self, model_name: str) -> None:
        """Make ``model_name`` the active conditional model, loading it if needed.

        Args:
            model_name: Either 'text300M' or 'image300M'.

        Raises:
            ValueError: If ``model_name`` is not a supported model.
        """
        if model_name not in self._MODEL_NAMES:
            raise ValueError(f'Unknown model name: {model_name!r}; '
                             f'expected one of {self._MODEL_NAMES}')
        if model_name == self.model_name:
            return
        # Drop the previous model *before* loading the next, so both never
        # occupy memory at once. Clear the name too: if the load below
        # raises, a later call must not treat model=None as already loaded.
        self.model = None
        self.model_name = ''
        self._release_memory()
        self.model = load_model(model_name, device=self.device)
        self.model_name = model_name
        # Release temporaries left over from loading the checkpoint.
        self._release_memory()

    def to_glb(self, latent: torch.Tensor) -> str:
        """Decode ``latent`` to a mesh and write it out as a GLB file.

        Args:
            latent: A single latent without a batch axis (see
                ``decode_latent_mesh``).

        Returns:
            Path of a new temporary ``.glb`` file. It is created with
            ``delete=False`` so it outlives this call; removing it is the
            caller's responsibility.
        """
        tri_mesh = decode_latent_mesh(self.xm, latent).tri_mesh()

        # Round-trip through PLY to get a trimesh object. The file lives in a
        # temporary directory, so it is removed once it has been read back.
        with tempfile.TemporaryDirectory() as tmp_dir:
            ply_path = os.path.join(tmp_dir, 'mesh.ply')
            # Close (and therefore flush) the file before trimesh re-opens it
            # by name; otherwise buffered bytes could be missing.
            with open(ply_path, 'wb') as ply_file:
                tri_mesh.write_ply(ply_file)
            mesh = trimesh.load(ply_path)

        # Rotate -90 deg about X, then 180 deg about Y. This is presumably to
        # match glTF's Y-up frame -- NOTE(review): confirm Shap-E's
        # up-axis convention.
        rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])
        mesh = mesh.apply_transform(rot)
        rot = trimesh.transformations.rotation_matrix(np.pi, [0, 1, 0])
        mesh = mesh.apply_transform(rot)

        # Reserve a persistent path, closing our handle right away so the
        # export below can open the file by name on any platform.
        with tempfile.NamedTemporaryFile(suffix='.glb',
                                         delete=False) as glb_file:
            glb_path = glb_file.name
        mesh.export(glb_path, file_type='glb')
        return glb_path

    def _sample_latent(self, model_kwargs: dict, guidance_scale: float,
                       num_steps: int) -> torch.Tensor:
        """Sample one latent from the active model with the Karras sampler.

        Seeding is left to the caller so that the RNG state is set at the
        same point as before.
        """
        latents = sample_latents(
            batch_size=1,
            model=self.model,
            diffusion=self.diffusion,
            guidance_scale=guidance_scale,
            model_kwargs=model_kwargs,
            progress=True,
            clip_denoised=True,
            use_fp16=True,
            use_karras=True,
            karras_steps=num_steps,
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
        )
        return latents[0]

    def run_text(self,
                 prompt: str,
                 seed: int = 0,
                 guidance_scale: float = 15.0,
                 num_steps: int = 64) -> str:
        """Generate a 3D mesh from ``prompt``; return the path to a GLB file."""
        self.load_model('text300M')
        torch.manual_seed(seed)
        latent = self._sample_latent(dict(texts=[prompt]), guidance_scale,
                                     num_steps)
        return self.to_glb(latent)

    def run_image(self,
                  image_path: str,
                  seed: int = 0,
                  guidance_scale: float = 3.0,
                  num_steps: int = 64) -> str:
        """Generate a 3D mesh from the image at ``image_path``; return a GLB path."""
        self.load_model('image300M')
        torch.manual_seed(seed)
        image = load_image(image_path)
        latent = self._sample_latent(dict(images=[image]), guidance_scale,
                                     num_steps)
        return self.to_glb(latent)