Spaces:
Building
on
T4
Building
on
T4
File size: 4,332 Bytes
71d12ce caef638 71d12ce cf3fc03 71d12ce 6b3c1e9 c49ce5c 71d12ce 269cf5b c49ce5c 71d12ce b796e0c 942501f 71d12ce b03eeaf 71d12ce cf75aba 71d12ce eee51b6 71d12ce e245f8d 71d12ce a4d4a45 71d12ce 2acae5e 71d12ce 2acae5e 71d12ce f107a56 a4d4a45 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
from diffusers import StableDiffusionLDM3DPipeline
import gradio as gr
import torch
from PIL import Image
import base64
from io import BytesIO
from tempfile import NamedTemporaryFile
from pathlib import Path
# Scratch directory for the generated PNG files; "tmp/" is also listed in
# ``allowed_paths`` at launch so the files written here can be served.
Path("tmp").mkdir(exist_ok=True)

# Use the GPU when one is available; half precision is only selected on CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device is {device}")
torch_type = torch.float16 if device == "cuda" else torch.float32

# Load the LDM3D panorama pipeline once at import time; it is shared by every
# request handled by ``predict``.
pipe = StableDiffusionLDM3DPipeline.from_pretrained(
    "Intel/ldm3d-pano",
    torch_dtype=torch_type
    # , safety_checker=None
)

if device == "cuda":
    pipe.enable_xformers_memory_efficient_attention()
    # Model CPU offload moves each sub-model to the GPU on demand. The
    # pipeline is therefore NOT moved to CUDA by hand first: doing so only
    # adds a wasted host<->device round trip and works against the memory
    # savings that offloading is meant to provide.
    pipe.enable_model_cpu_offload()
else:
    # No accelerator: keep the whole pipeline on the CPU.
    pipe.to(device)
def get_iframe(rgb_path: str, depth_path: str, viewer_mode: str = "6DOF") -> str:
    """Build the ``<iframe>`` markup that embeds one of the static viewer pages.

    The image locations are handed to the viewer page through ``data-rgb`` and
    ``data-depth`` attributes, each formatted as ``/file=<path>``.

    Args:
        rgb_path: Path of the RGB image file.
        depth_path: Path of the depth image file.
        viewer_mode: ``"6DOF"`` selects ``static/three6dof.html``; any other
            value selects ``static/depthmap.html``.

    Returns:
        The iframe element as an HTML string.
    """
    # The two viewers differ only in the page they load, so pick the page and
    # fill a single template instead of duplicating the markup per branch.
    viewer_page = "three6dof.html" if viewer_mode == "6DOF" else "depthmap.html"
    rgb_url = f"/file={rgb_path}"
    depth_url = f"/file={depth_path}"
    return (
        f'<iframe src="file=static/{viewer_page}" width="100%" height="500px" '
        f'data-rgb="{rgb_url}" data-depth="{depth_url}"></iframe>'
    )
def _save_png(image) -> str:
    """Save *image* to a new ``.png`` file under ``tmp`` and return its path.

    The file is created with ``delete=False`` so it outlives this call and can
    still be read after the function returns.
    """
    with NamedTemporaryFile(suffix=".png", delete=False, dir="tmp") as handle:
        image.save(handle.name)
    return handle.name


def predict(
    prompt: str,
    negative_prompt: str,
    guidance_scale: float = 5.0,
    seed: int = 0,
    randomize_seed: bool = True,
):
    """Generate an RGB + depth image pair and the viewer markup for it.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: Seed to use when ``randomize_seed`` is false.
        randomize_seed: When true, ignore ``seed`` and draw a fresh one.

    Returns:
        A tuple ``(rgb_path, depth_path, used_seed, iframe_html)`` where the
        two paths point at PNG files under ``tmp``, ``used_seed`` is the seed
        the generator was actually seeded with for this image, and
        ``iframe_html`` is the markup produced by ``get_iframe``.
    """
    # Seed a private generator *before* sampling and remember the value.
    # Calling ``Generator.seed()`` after the run (as was done previously)
    # reseeds the generator and returns that new value, so the reported seed
    # never matched the image; an unseeded ``torch.Generator()`` also starts
    # from the same default seed every time, so nothing was randomized.
    generator = torch.Generator()
    if randomize_seed:
        # seed() draws a non-deterministic seed, applies it, and returns it.
        used_seed = generator.seed()
    else:
        # UI widgets may deliver the number as a float; manual_seed needs int.
        used_seed = int(seed)
        generator.manual_seed(used_seed)

    output = pipe(
        prompt,
        width=1024,
        height=512,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        generator=generator,
        num_inference_steps=50,
    )  # type: ignore
    rgb_image, depth_image = output.rgb[0], output.depth[0]  # type: ignore

    rgb_path = _save_png(rgb_image)
    depth_path = _save_png(depth_image)
    iframe = get_iframe(rgb_path, depth_path)
    return rgb_path, depth_path, used_seed, iframe
# Gradio front end: controls on the left, viewer and generated images on the
# right, one cached example, and a button that runs ``predict``.
# NOTE(review): the source lost its indentation, so the container nesting
# below is reconstructed from the order of the calls -- confirm the layout.
with gr.Blocks() as block:
    gr.Markdown(
        """
## LDM3d Demo
[Model card](https://huggingface.co/Intel/ldm3d-pano)
[Diffusers docs](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/ldm3d_diffusion)
For better results, specify "360 view of" or "panoramic view of" in the prompt
"""
    )
    with gr.Row():
        # Left column: generation controls.
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt")
            negative_prompt = gr.Textbox(label="Negative Prompt")
            guidance_scale = gr.Slider(
                label="Guidance Scale",
                minimum=0,
                maximum=10,
                step=0.1,
                value=5.0,
            )
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=2**64 - 1,
                step=1,
            )
            generated_seed = gr.Number(label="Generated Seed")
            markdown = gr.Markdown(label="Output Box")
            with gr.Row():
                new_btn = gr.Button("New Image")
        # Right column: embedded viewer above the two raw images.
        with gr.Column(scale=2):
            html = gr.HTML()
            with gr.Row():
                rgb = gr.Image(label="RGB Image", type="filepath")
                depth = gr.Image(label="Depth Image", type="filepath")

    # The example runner and the button feed the same components to
    # ``predict`` and write its four results to the same components.
    predict_inputs = [prompt, negative_prompt, guidance_scale, seed, randomize_seed]
    predict_outputs = [rgb, depth, generated_seed, html]

    gr.Examples(
        examples=[
            ["360 view of a large bedroom", "", 7.0, 42, False],
        ],
        inputs=predict_inputs,
        outputs=predict_outputs,
        fn=predict,
        cache_examples=True,
    )
    new_btn.click(
        fn=predict,
        inputs=predict_inputs,
        outputs=predict_outputs,
    )

# Start the server; only these directories may be served as files.
block.launch(
    allowed_paths=["assets/", "static/", "tmp/"],
)
|