"""Gradio demo that generates images with the Open Diffusion v1 model.

Prompts are screened against a blocked-word list before being sent to the
Hugging Face Inference API; the returned images are shown in a gallery.
"""

import io
import os
import random
import re

import gradio as gr
import requests
from datasets import load_dataset
from huggingface_hub import login
from PIL import Image

API_URL = "https://api-inference.huggingface.co/models/openskyml/open-diffusion-v1"
API_TOKEN = os.getenv("HF_READ_TOKEN")  # it is free
headers = {"Authorization": f"Bearer {API_TOKEN}"}

# Upper bound (seconds) for a single inference request so a stalled
# connection cannot hang the UI callback forever.
REQUEST_TIMEOUT = 120

login(token=API_TOKEN)

word_list_dataset = load_dataset(
    "openskyml/bad-words-prompt-list",
    data_files="en.txt",
    use_auth_token=True,
)
word_list = word_list_dataset["train"]['text']

# Compile the filters once at import time. Each word is escaped so that regex
# metacharacters in the list are matched literally, and blank entries are
# skipped because r"\b\b" would match nearly every prompt.
_BLOCKED_PATTERNS = [
    re.compile(rf"\b{re.escape(word.strip())}\b", re.IGNORECASE)
    for word in word_list
    if word and word.strip()
]


def query(prompt, is_negative=False, steps=7, cfg_scale=7, seed=None, num_images=4):
    """Generate ``num_images`` images for ``prompt`` via the inference API.

    Args:
        prompt: Text prompt; ``", 8k"`` is appended before it is sent.
        is_negative: Forwarded unchanged as the ``is_negative`` payload field.
            The UI wires the negative-prompt textbox to this parameter.
        steps: Forwarded as the ``steps`` payload field.
        cfg_scale: Forwarded as the ``cfg_scale`` payload field.
        seed: Seed sent with every request. When ``None`` a fresh random seed
            is drawn for each image; when given, all requests share it.
        num_images: Number of requests to issue (one image per request).

    Returns:
        A list of ``PIL.Image.Image`` objects, one per request.

    Raises:
        gr.Error: If the prompt contains a blocked word, the request fails,
            or the response body cannot be decoded as an image.
    """
    if any(pattern.search(prompt) for pattern in _BLOCKED_PATTERNS):
        raise gr.Error("Unsafe content found. Please try again with different prompts.")

    images = []
    for _ in range(num_images):
        payload = {
            "inputs": prompt + ", 8k",
            "is_negative": is_negative,
            "steps": steps,
            "cfg_scale": cfg_scale,
            "seed": seed if seed is not None else random.randint(-1, 2147483647),
        }

        try:
            response = requests.post(
                API_URL, headers=headers, json=payload, timeout=REQUEST_TIMEOUT
            )
        except requests.RequestException as err:
            raise gr.Error(f"Could not reach the image generation service: {err}") from err

        # On failure the body is an error document rather than image bytes,
        # so report the status instead of letting Image.open fail obscurely.
        if not response.ok:
            raise gr.Error(f"Image generation failed (HTTP {response.status_code}).")

        try:
            image = Image.open(io.BytesIO(response.content))
        except OSError as err:  # includes PIL.UnidentifiedImageError
            raise gr.Error("The image generation service returned invalid image data.") from err
        images.append(image)

    return images


# Value kept identical to the original single-line stylesheet string.
css = (
    " #gallery { min-height: 22rem; margin-bottom: 15px; "
    "margin-left: auto; margin-right: auto; "
    "border-bottom-right-radius: .5rem !important; "
    "border-bottom-left-radius: .5rem !important; } "
    "#gallery>div>.h-full { min-height: 20rem; } "
)

# NOTE(review): the source arrived with its indentation collapsed, so the
# nesting of the layout contexts below is a reconstruction -- confirm it
# against the intended UI.
with gr.Blocks(css=css) as demo:
    gr.HTML(
        """

Open Diffusion 1.0 Demo

"""
    )
    with gr.Group():
        with gr.Box():
            with gr.Row():
                with gr.Column():
                    gallery_output = gr.Gallery(
                        label="Generated images",
                        show_label=False,
                        elem_id="gallery",
                    ).style(grid=[2], height="auto")

            with gr.Row(elem_id="prompt-container"):
                with gr.Column():
                    text_prompt = gr.Textbox(
                        show_label=False,
                        placeholder="Enter your prompt",
                        max_lines=1,
                        elem_id="prompt-text-input",
                    ).style(
                        border=(True, False, True, True),
                        rounded=(True, False, False, True),
                        container=False,
                    )
                    negative_prompt = gr.Textbox(
                        show_label=False,
                        placeholder="Enter a negative",
                        max_lines=1,
                        elem_id="negative-prompt-text-input",
                    ).style(
                        border=(True, False, True, True),
                        rounded=(True, False, False, True),
                        container=False,
                    )
                text_button = gr.Button("Generate").style(
                    margin=False,
                    rounded=(False, True, True, False),
                    full_width=False,
                )

    # The second input (negative prompt) lands in query's ``is_negative``.
    text_button.click(query, inputs=[text_prompt, negative_prompt], outputs=gallery_output)

demo.launch(show_api=False)