# Spaces: Sleeping (Hugging Face page status banner captured along with this file; not part of the program)
import os | |
import cv2 | |
import gradio as gr | |
import numpy as np | |
import spaces | |
import torch | |
import torch.nn.functional as F | |
from gradio.themes.utils import sizes | |
from PIL import Image | |
from torchvision import transforms | |
import tempfile | |
class Config:
    """Filesystem locations and checkpoint file names used by the demo."""

    # Bundled assets (example images and checkpoints) live next to this file.
    ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
    CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")

    # Normal-estimation checkpoints, keyed by the size label offered in the UI.
    # Dict order is the order the dropdown shows.
    CHECKPOINTS = {
        "0.3b": "sapiens_0.3b_normal_render_people_epoch_66_torchscript.pt2",
        "0.6b": "sapiens_0.6b_normal_render_people_epoch_200_torchscript.pt2",
        "1b": "sapiens_1b_normal_render_people_epoch_115_torchscript.pt2",
        "2b": "sapiens_2b_normal_render_people_epoch_70_torchscript.pt2",
    }

    # Background-removal choices; the entry mapped to None is the option that
    # performs no segmentation.
    SEG_CHECKPOINTS = {
        "fg-bg-1b (recommended)": "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2",
        "no-bg-removal": None,
        "part-seg-1b": "sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2",
    }
class ModelManager:
    """Stateless helpers for loading TorchScript checkpoints and running them."""

    @staticmethod
    def load_model(checkpoint_name: str, device: str = "cuda"):
        """Load a TorchScript checkpoint from ``Config.CHECKPOINTS_DIR``.

        Args:
            checkpoint_name: File name inside ``Config.CHECKPOINTS_DIR``, or
                None to request no model.
            device: Device the loaded model is moved to (default ``"cuda"``).

        Returns:
            The model in eval mode on ``device``, or None when
            ``checkpoint_name`` is None.
        """
        if checkpoint_name is None:
            return None
        checkpoint_path = os.path.join(Config.CHECKPOINTS_DIR, checkpoint_name)
        model = torch.jit.load(checkpoint_path)
        model.eval()
        model.to(device)
        return model

    @staticmethod
    def run_model(model, input_tensor, height, width):
        """Run ``model`` on ``input_tensor`` and resize its output to ``(height, width)``.

        The forward pass runs under ``torch.inference_mode()``. That means no
        autograd graph is recorded, and callers can hand the result straight
        to ``.numpy()``, which refuses tensors that require grad. Bilinear
        resizing (``align_corners=False``) requires the model output to be a
        4-D (N, C, H, W) tensor.
        """
        with torch.inference_mode():
            output = model(input_tensor)
            return F.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)
class ImageProcessor:
    """Run normal estimation (with optional background masking) on PIL images."""

    def __init__(self):
        # Resize to (1024, 768), convert to a [0, 1] tensor, then normalize with
        # per-channel mean/std written on the 0-255 scale and divided by 255.
        self.transform_fn = transforms.Compose([
            transforms.Resize((1024, 768)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[123.5/255, 116.5/255, 103.5/255], std=[58.5/255, 57.0/255, 57.5/255]),
        ])

    # `spaces` is imported by this file for this decorator. On ZeroGPU
    # hardware it attaches a GPU for the duration of the call; per the
    # `spaces` documentation it has no effect elsewhere.
    @spaces.GPU
    def process_image(self, image: Image.Image, normal_model_name: str, seg_model_name: str):
        """Estimate a surface-normal map for ``image``.

        Args:
            image: Input picture; converted to RGB before preprocessing.
            normal_model_name: Key into ``Config.CHECKPOINTS``.
            seg_model_name: Key into ``Config.SEG_CHECKPOINTS``; an entry whose
                checkpoint is None (``"no-bg-removal"``) skips segmentation.

        Returns:
            ``(visualization, npy_path)``: an RGB ``PIL.Image`` of the normals
            and the path of a ``.npy`` file holding the raw float normal map,
            in which background pixels are NaN when segmentation ran.

        Raises:
            ValueError: if ``image`` is None.
            KeyError: if a model name is not a configured key.
        """
        if image is None:
            raise ValueError("No input image provided.")
        # Normalize() above expects exactly three channels; RGBA / grayscale
        # inputs would otherwise fail inside the transform.
        image = image.convert("RGB")

        # Inference only: records no autograd graph, so results can be moved
        # to NumPy (``.numpy()`` refuses tensors that require grad).
        with torch.inference_mode():
            # Models are loaded per call rather than stored as class attributes.
            normal_model = ModelManager.load_model(Config.CHECKPOINTS[normal_model_name])
            input_tensor = self.transform_fn(image).unsqueeze(0).to("cuda")

            normal_output = ModelManager.run_model(normal_model, input_tensor, image.height, image.width)
            # Take batch element 0 explicitly; squeeze() would also drop a
            # height or width of 1. Transposed to (H, W, C).
            normal_map = normal_output[0].cpu().numpy().transpose(1, 2, 0)
            # Separate copy: the .npy export and the visualization mark the
            # background differently.
            normal_map_vis = normal_map.copy()

            seg_checkpoint = Config.SEG_CHECKPOINTS[seg_model_name]
            if seg_checkpoint is not None:
                seg_model = ModelManager.load_model(seg_checkpoint)
                seg_output = ModelManager.run_model(seg_model, input_tensor, image.height, image.width)
                # Any class index other than 0 is treated as foreground.
                seg_mask = (seg_output.argmax(dim=1) > 0).float().cpu().numpy()[0]
                normal_map[seg_mask == 0] = np.nan  # background -> NaN in the .npy export
                normal_map_vis[seg_mask == 0] = -1  # background -> constant vector for display

        normal_map_vis = self.visualize_normal_map(normal_map_vis)

        # NamedTemporaryFile avoids the name race of the deprecated
        # tempfile.mktemp(); delete=False keeps the file so the returned path
        # stays valid after this function returns.
        with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as npy_file:
            np.save(npy_file, normal_map)
        return Image.fromarray(normal_map_vis), npy_file.name

    @staticmethod
    def visualize_normal_map(normal_map):
        """Map normal vectors (last axis) to an 8-bit image.

        Each vector is scaled to unit length (1e-5 guards against division by
        zero), then components in [-1, 1] are mapped linearly to [0, 255].
        """
        normal_map_norm = np.linalg.norm(normal_map, axis=-1, keepdims=True)
        normal_map_normalized = normal_map / (normal_map_norm + 1e-5)
        return ((normal_map_normalized + 1) / 2 * 255).astype(np.uint8)
class GradioInterface:
    """Build the Gradio Blocks UI for the Sapiens normal-estimation demo."""

    def __init__(self):
        self.image_processor = ImageProcessor()

    def create_interface(self):
        """Assemble and return the ``gr.Blocks`` app (the caller launches it).

        Layout: a styled HTML header, then an input column (image, model
        dropdowns, examples from ``assets/images``) beside an output column
        (visualized normals, downloadable ``.npy``, Run button).
        """
        app_styles = """
<style>
/* Global Styles */
body, #root {
font-family: Helvetica, Arial, sans-serif;
background-color: #1a1a1a;
color: #fafafa;
}
/* Header Styles */
.app-header {
background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
padding: 24px;
border-radius: 8px;
margin-bottom: 24px;
text-align: center;
}
.app-title {
font-size: 48px;
margin: 0;
color: #fafafa;
}
.app-subtitle {
font-size: 24px;
margin: 8px 0 16px;
color: #fafafa;
}
.app-description {
font-size: 16px;
line-height: 1.6;
opacity: 0.8;
margin-bottom: 24px;
}
/* Button Styles */
.publication-links {
display: flex;
justify-content: center;
flex-wrap: wrap;
gap: 8px;
margin-bottom: 16px;
}
.publication-link {
display: inline-flex;
align-items: center;
padding: 8px 16px;
background-color: #333;
color: #fff !important;
text-decoration: none !important;
border-radius: 20px;
font-size: 14px;
transition: background-color 0.3s;
}
.publication-link:hover {
background-color: #555;
}
.publication-link i {
margin-right: 8px;
}
/* Content Styles */
.content-container {
background-color: #2a2a2a;
border-radius: 8px;
padding: 24px;
margin-bottom: 24px;
}
/* Image Styles */
.image-preview img {
max-width: 100%;
max-height: 512px;
margin: 0 auto;
border-radius: 4px;
display: block;
}
/* Control Styles */
.control-panel {
background-color: #333;
padding: 16px;
border-radius: 8px;
margin-top: 16px;
}
/* Gradio Component Overrides */
.gr-button {
background-color: #4a4a4a;
color: #fff;
border: none;
border-radius: 4px;
padding: 8px 16px;
cursor: pointer;
transition: background-color 0.3s;
}
.gr-button:hover {
background-color: #5a5a5a;
}
.gr-input, .gr-dropdown {
background-color: #3a3a3a;
color: #fff;
border: 1px solid #4a4a4a;
border-radius: 4px;
padding: 8px;
}
.gr-form {
background-color: transparent;
}
.gr-panel {
border: none;
background-color: transparent;
}
/* Override any conflicting styles from Bulma */
.button.is-normal.is-rounded.is-dark {
color: #fff !important;
text-decoration: none !important;
}
</style>
"""
        # f-string: only {app_styles} is interpolated; the CSS braces live in
        # app_styles itself, so they need no escaping here.
        header_html = f"""
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.3/css/bulma.min.css">
<link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css">
{app_styles}
<div class="app-header">
<h1 class="app-title">Sapiens: Normal Estimation</h1>
<h2 class="app-subtitle">ECCV 2024 (Oral)</h2>
<p class="app-description">
Meta presents Sapiens, foundation models for human tasks pretrained on 300 million human images.
This demo showcases the finetuned normal estimation model. <br>
Checkout other normal estimation baselines to compare: <a href="https://huggingface.co/spaces/Stable-X/normal-estimation-arena" style="color: #3273dc;">normal-estimation-arena</a>
</p>
<p style="font-size: 12px; opacity: 0.7;">
Space modified from <a href="https://huggingface.co/spaces/fashn-ai/sapiens-body-part-segmentation" style="color: #3273dc;">fashn-ai</a>
</p>
<div class="publication-links">
<a href="https://arxiv.org/abs/2408.12569" class="publication-link">
<i class="fas fa-file-pdf"></i>arXiv
</a>
<a href="https://github.com/facebookresearch/sapiens" class="publication-link">
<i class="fab fa-github"></i>Code
</a>
<a href="https://about.meta.com/realitylabs/codecavatars/sapiens/" class="publication-link">
<i class="fas fa-globe"></i>Meta
</a>
<a href="https://rawalkhirodkar.github.io/sapiens" class="publication-link">
<i class="fas fa-chart-bar"></i>Results
</a>
</div>
<div class="publication-links">
<a href="https://huggingface.co/spaces/facebook/sapiens_pose" class="publication-link">
<i class="fas fa-user"></i>Demo-Pose
</a>
<a href="https://huggingface.co/spaces/facebook/sapiens_seg" class="publication-link">
<i class="fas fa-puzzle-piece"></i>Demo-Seg
</a>
<a href="https://huggingface.co/spaces/facebook/sapiens_depth" class="publication-link">
<i class="fas fa-cube"></i>Demo-Depth
</a>
<a href="https://huggingface.co/spaces/facebook/sapiens_normal" class="publication-link">
<i class="fas fa-vector-square"></i>Demo-Normal
</a>
</div>
</div>
"""

        def process_image(image, normal_model_name, seg_model_name):
            # Show a readable UI error instead of a stack trace when Run is
            # clicked before an image has been provided.
            if image is None:
                raise gr.Error("Please upload an input image first.")
            result, npy_path = self.image_processor.process_image(image, normal_model_name, seg_model_name)
            return result, npy_path

        # Passed as js= to gr.Blocks: reloads the page with ?__theme=dark
        # unless that query parameter is already set.
        js_func = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'dark') {
url.searchParams.set('__theme', 'dark');
window.location.href = url.href;
}
}
"""

        # Sorted so the examples gallery has a stable order (os.listdir order
        # is arbitrary); hidden files such as .DS_Store are not images.
        images_dir = os.path.join(Config.ASSETS_DIR, "images")
        example_paths = [
            os.path.join(images_dir, name)
            for name in sorted(os.listdir(images_dir))
            if not name.startswith(".")
        ]

        with gr.Blocks(js=js_func, theme=gr.themes.Default()) as demo:
            gr.HTML(header_html)
            with gr.Row(elem_classes="content-container"):
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil", format="png", elem_classes="image-preview")
                    with gr.Row(elem_classes="control-panel"):
                        normal_model_name = gr.Dropdown(
                            label="Normal Model Size",
                            choices=list(Config.CHECKPOINTS.keys()),
                            value="1b",
                        )
                        seg_model_name = gr.Dropdown(
                            label="Background Removal Model",
                            choices=list(Config.SEG_CHECKPOINTS.keys()),
                            value="fg-bg-1b (recommended)",
                        )
                    gr.Examples(
                        inputs=input_image,
                        examples_per_page=14,
                        examples=example_paths,
                    )
                with gr.Column():
                    result_image = gr.Image(label="Normal Estimation Result", type="pil", elem_classes="image-preview")
                    npy_output = gr.File(label="Output (.npy). Note: Background normal is NaN.")
                    run_button = gr.Button("Run", elem_classes="gr-button")
            run_button.click(
                fn=process_image,
                inputs=[input_image, normal_model_name, seg_model_name],
                outputs=[result_image, npy_output],
            )
        return demo
def main():
    """Optionally enable TF32, then build the Gradio app and launch it."""
    has_cuda = torch.cuda.is_available()
    # TF32 is only turned on for device 0 with compute capability major >= 8
    # (Ampere or newer), where PyTorch supports TF32 math.
    if has_cuda and torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    app = GradioInterface().create_interface()
    app.launch(share=False)


if __name__ == "__main__":
    main()