File size: 3,324 Bytes
2d6061e 2590cf7 2d6061e 43af989 499970a 2d6061e 71a5076 2d6061e 1b5d275 71a5076 2d6061e 66863fc 2d6061e 66863fc 2d6061e 66863fc 2d6061e 614d206 2d6061e aad2eee 2d6061e 22deffa 2d6061e ce3076c 34b2591 ce3076c 2d6061e a4a0425 2d6061e ce3076c 5bf6550 b2977aa 96bbd5e 5bf6550 1e75837 2d6061e 520f6d6 614d206 2d6061e 4dd8b4f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import gradio as gr
import requests
import io
import re
import random
import os
from PIL import Image
from datasets import load_dataset
from huggingface_hub import login
# Credentials and endpoint for the hosted inference API. The read-scoped token
# comes from the HF_READ_TOKEN environment variable and is read once, then
# reused for both the hub login and the request headers.
API_TOKEN = os.getenv("HF_READ_TOKEN")
login(token=API_TOKEN)
API_URL = "https://api-inference.huggingface.co/models/openskyml/open-diffusion-v1"
headers = {"Authorization": f"Bearer {API_TOKEN}"}

# Word blocklist loaded from the hub; query() rejects prompts containing
# any of these words.
word_list_dataset = load_dataset(
    "openskyml/bad-words-prompt-list",
    data_files="en.txt",
    use_auth_token=True,
)
word_list = word_list_dataset["train"]["text"]
def _contains_blocked_word(prompt, words):
    """Return True if *prompt* contains any entry of *words* as a whole word.

    Entries are stripped and regex-escaped so metacharacters are matched
    literally. Matching is case-insensitive. Blank entries are skipped,
    because an empty entry would make the word-boundary pattern match
    almost any prompt.
    """
    alternatives = [re.escape(word.strip()) for word in words if word and word.strip()]
    if not alternatives:
        return False
    # One combined alternation finds a match iff some individual
    # word-boundary pattern would have matched.
    pattern = rf"\b(?:{'|'.join(alternatives)})\b"
    return re.search(pattern, prompt, flags=re.IGNORECASE) is not None


def _generate_image(payload, timeout):
    """POST one generation request and decode the response body as a PIL image.

    Raises gr.Error (shown to the user by Gradio) when the request fails or
    returns a non-success status. Without this check, PIL would fail on a
    non-image body with an unhelpful error.
    """
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    except requests.RequestException as err:
        raise gr.Error(f"Image generation request failed: {err}") from err
    if not response.ok:
        raise gr.Error(
            f"Image generation failed ({response.status_code}): {response.text[:200]}"
        )
    return Image.open(io.BytesIO(response.content))


def query(prompt, is_negative=False, steps=5, cfg_scale=7, seed=None, num_images=4, timeout=120):
    """Generate ``num_images`` images for ``prompt`` through the inference API.

    Args:
        prompt: Text prompt. ", 8k" is appended before sending.
        is_negative: Forwarded as the payload's ``is_negative`` field. The UI
            wires the negative-prompt textbox here, so it usually receives a
            string.
        steps: Forwarded unchanged in the payload.
        cfg_scale: Forwarded unchanged in the payload.
        seed: Seed sent with every request. When None, each request gets its
            own random seed. Note that a fixed seed is reused for all
            ``num_images`` requests.
        num_images: Number of sequential requests (one image each).
        timeout: Per-request HTTP timeout in seconds, so a stalled request
            cannot hang the handler forever.

    Returns:
        A list of ``num_images`` PIL images.

    Raises:
        gr.Error: If the prompt contains a blocklisted word, or if a
            generation request fails.
    """
    if _contains_blocked_word(prompt, word_list):
        raise gr.Error("Unsafe content found. Please try again with different prompts.")
    images = []
    for _ in range(num_images):
        payload = {
            "inputs": prompt + ", 8k",
            "is_negative": is_negative,
            "steps": steps,
            "cfg_scale": cfg_scale,
            # Evaluated per iteration so each image gets a fresh random seed.
            "seed": seed if seed is not None else random.randint(-1, 2147483647),
        }
        images.append(_generate_image(payload, timeout))
    return images
# Custom stylesheet passed to gr.Blocks(css=css) when the UI is built.
# "#gallery" targets the Gallery created with elem_id="gallery". The
# "#prompt-text-input" / "#negative-prompt-text-input" rules only apply to
# components whose elem_id matches those ids.
# NOTE(review): "#component-16" appears to target a Gradio auto-numbered
# component id, which presumably changes when components are added or
# reordered. Confirm it still hits the intended element. No component in this
# file visibly uses the "image_duplication" class.
css = """
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif;
}
#gallery {
min-height: 22rem;
margin-bottom: 15px;
margin-left: auto;
margin-right: auto;
border-bottom-right-radius: .5rem !important;
border-bottom-left-radius: .5rem !important;
}
#gallery>div>.h-full {
min-height: 20rem;
}
#prompt-text-input, #negative-prompt-text-input{padding: .45rem 0.625rem}
#component-16{border-top-width: 1px!important;margin-top: 1em}
.image_duplication{position: absolute; width: 100px; left: 50px}
"""
# Build the Gradio UI: a title banner, a two-column image gallery, prompt and
# negative-prompt inputs, and a Generate button wired to query().
with gr.Blocks(css=css) as demo:
    gr.HTML(
        """
<div style="text-align: center; margin: 0 auto;">
<div
style="
display: inline-flex;
align-items: center;
gap: 0.8rem;
font-size: 1.75rem;
"
>
<h1 style="font-weight: 900; margin-bottom: 7px;margin-top:5px">
Open Diffusion 1.0 Demo
</h1>
</div>
</div>
"""
    )
    with gr.Row():
        gallery_output = gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[2])
    with gr.Row():
        with gr.Box():
            # The elem_ids match the "#prompt-text-input" /
            # "#negative-prompt-text-input" rules in the css stylesheet;
            # without them those rules would never apply.
            text_prompt = gr.Textbox(
                show_label=False,
                placeholder="Enter your prompt",
                max_lines=1,
                elem_id="prompt-text-input",
            )
            negative_prompt = gr.Textbox(
                show_label=False,
                placeholder="Enter a negative prompt",
                max_lines=1,
                elem_id="negative-prompt-text-input",
            )
        text_button = gr.Button(
            "Generate",
            icon="https://www.gstatic.com/android/keyboard/emojikitchen/20210521/u1fa84/u1fa84_u1fa84.png",
        )
    # The negative prompt lands in query()'s second parameter (is_negative).
    text_button.click(query, inputs=[text_prompt, negative_prompt], outputs=gallery_output)

demo.launch(show_api=False, server_name="0.0.0.0", server_port=7860)