# Spaces: Paused (status text captured from the hosting page; not part of the program)
import functools
import io
from urllib.request import urlopen

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image, ImageOps
from PIL.ImageDraw import Draw
# Build the diffusion pipeline a single time at import so every request reuses it.
# NOTE(review): model_id is an absolute path on a local machine -- confirm it
# exists wherever this script is deployed.
model_id = '/Users/tomerkeren/DeciDiffusion-v1-0'

# Run on the GPU when torch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The same location provides both the weights and the custom pipeline code.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    custom_pipeline=model_id,
    torch_dtype=torch.float32,
)

# Replace the default UNet with the one stored under the 'flexible_unet' subfolder.
pipe.unet = pipe.unet.from_pretrained(
    model_id,
    subfolder='flexible_unet',
    torch_dtype=torch.float32,
)

pipe = pipe.to(device)
def read_content(file_path: str) -> str:
    """Return the entire text of ``file_path``, decoded as UTF-8.

    Args:
        file_path: Path of the file to read.

    Returns:
        The file's full contents as one string.
    """
    with open(file_path, mode='r', encoding='utf-8') as handle:
        return handle.read()
# Location of the logo pasted into the caption band under each generated image.
_DECI_LOGO_URL = 'https://huggingface.co/spaces/Deci/DeciDiffusion-v1-0/resolve/main/deci_logo_white.png'
# Prompts longer than this are shortened (with "...") before being drawn.
_MAX_CAPTION_CHARS = 52
# Height in pixels of the white band appended below the generated image.
_CAPTION_BAND_HEIGHT = 64


@functools.lru_cache(maxsize=1)
def _load_deci_logo():
    """Download and decode the Deci logo once; later calls reuse the cached image.

    ``PIL.Image.open`` accepts a filename or a file object, not a URL, so the
    bytes are fetched explicitly and wrapped in an in-memory buffer.

    Raises:
        OSError: if the download fails or the payload is not a readable image
            (``urllib.error.URLError`` is an ``OSError`` subclass).
    """
    with urlopen(_DECI_LOGO_URL, timeout=10) as response:
        payload = response.read()
    logo = Image.open(io.BytesIO(payload))
    # Image.open is lazy; decode now so the cached image is fully loaded.
    logo.load()
    return logo


def predict(_prompt: str, _steps: int = 30, _seed: int = 42, _guidance_scale: float = 7.5, _negative_prompt: str = ""):
    """Generate one image for ``_prompt`` and append a captioned band below it.

    Args:
        _prompt: Text prompt passed to the pipeline; also drawn (shortened to
            52 characters plus "..." when longer) in the band under the image.
        _steps: Number of inference steps; cast to ``int`` before use.
        _seed: Seed for the torch random generator; cast to ``int`` before use.
        _guidance_scale: Guidance scale forwarded to the pipeline unchanged.
        _negative_prompt: Optional negative prompt; an empty value is sent to
            the pipeline as ``None``.

    Returns:
        The first image produced by the pipeline, extended at the bottom by a
        64-pixel white band containing the Deci logo and the prompt text.

    Raises:
        OSError: if the logo cannot be downloaded or decoded.
    """
    _negative_prompt = [_negative_prompt] if _negative_prompt else None
    output = pipe(prompt=[_prompt],
                  negative_prompt=_negative_prompt,
                  num_inference_steps=int(_steps),
                  guidance_scale=_guidance_scale,
                  # manual_seed() requires an integer, so cast like _steps above.
                  generator=torch.Generator(device).manual_seed(int(_seed)),
                  )
    output_image = output.images[0]

    # Add border beneath the image with Deci logo + prompt
    caption = _prompt
    if len(caption) > _MAX_CAPTION_CHARS:
        caption = caption[:_MAX_CAPTION_CHARS] + "..."

    # Remember where the original picture ends: that is the top of the new band.
    original_image_height = output_image.size[1]
    output_image = ImageOps.expand(output_image, border=(0, 0, 0, _CAPTION_BAND_HEIGHT), fill='white')

    deci_logo = _load_deci_logo()
    output_image.paste(deci_logo, (0, original_image_height))
    # The caption starts immediately to the right of the logo.
    Draw(output_image).text((deci_logo.size[0], original_image_height), caption, (127, 127, 127))
    return output_image
# Stylesheet handed to gr.Blocks: caps the container width, sets the page
# background image, styles the footer (including its dark-theme variant) and
# the acknowledgments heading, and declares a 'spin' keyframe animation.
css = '''
.gradio-container {
max-width: 1100px !important;
background-image: url(https://huggingface.co/spaces/Deci/Deci-DeciDiffusionClean/resolve/main/background-image.png);
background-size: cover;
background-position: center center;
background-repeat: no-repeat;
}
.footer {margin-bottom: 45px;margin-top: 35px !important;text-align: center;border-bottom: 1px solid #e5e5e5}
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
.dark .footer {border-color: #303030}
.dark .footer>p {background: #0b0f19}
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
@keyframes spin {
from {
transform: rotate(0deg);
}
to {
transform: rotate(360deg);
}
}
'''

# NOTE(review): the nesting below was reconstructed from a copy of this file
# whose indentation had been stripped -- confirm the layout against the running app.
with gr.Blocks(css=css, elem_id="total-container") as demo:
    # Page header markup is read from header.html.
    gr.HTML(read_content("header.html"))

    with gr.Row():
        # First column: prompt, advanced settings and the generate button.
        with gr.Column():
            with gr.Row(mobile_collapse=False, equal_height=True):
                prompt = gr.Textbox(
                    placeholder="Your prompt",
                    show_label=False,
                    elem_id="prompt",
                    autofocus=True,
                    lines=3,
                )

            with gr.Accordion(label="Advanced Settings", open=False):
                with gr.Row(mobile_collapse=False, equal_height=True):
                    steps = gr.Slider(
                        value=30, minimum=15, maximum=50, step=1,
                        label="steps", interactive=True,
                    )
                    seed = gr.Slider(
                        value=42, minimum=1, maximum=100, step=1,
                        label="seed", interactive=True,
                    )
                    guidance_scale = gr.Slider(
                        value=7.5, minimum=1, maximum=15, step=0.1,
                        label='guidance_scale', interactive=True,
                    )
                with gr.Row(mobile_collapse=False, equal_height=True):
                    negative_prompt = gr.Textbox(
                        label="negative_prompt",
                        placeholder="Your negative prompt",
                        info="what you don't want to see in the image",
                        lines=3,
                    )

            with gr.Row():
                btn = gr.Button(value="Generate!", elem_id="run_button")

        # Second column: the generated image.
        with gr.Column():
            image_out = gr.Image(label="Output", elem_id="output-img", height=400)

    # Clicking the button runs predict(); the input order matches its parameters.
    btn.click(
        fn=predict,
        inputs=[prompt, steps, seed, guidance_scale, negative_prompt],
        outputs=[image_out],
        api_name='run',
    )

    gr.HTML(
        """
<div class="footer">
<p>Model by <a href="https://deci.ai" style="text-decoration: underline;" target="_blank">Deci.ai</a> - Gradio Demo by 🤗 Hugging Face
</p>
</div>
"""
    )

# Enable the request queue (at most 50 waiting requests), then start the app.
demo.queue(max_size=50)
demo.launch()