File size: 3,324 Bytes
2d6061e
 
 
2590cf7
2d6061e
 
 
43af989
499970a
 
 
2d6061e
 
 
 
 
71a5076
 
2d6061e
1b5d275
71a5076
 
 
2d6061e
66863fc
2d6061e
66863fc
2d6061e
 
 
 
 
 
66863fc
2d6061e
 
614d206
2d6061e
aad2eee
2d6061e
 
 
 
22deffa
 
 
2d6061e
 
 
 
 
 
 
 
 
 
ce3076c
 
34b2591
ce3076c
 
 
2d6061e
 
 
a4a0425
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d6061e
ce3076c
5bf6550
b2977aa
96bbd5e
5bf6550
1e75837
 
 
 
2d6061e
 
520f6d6
614d206
2d6061e
4dd8b4f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gradio as gr
import requests
import io
import re
import random
import os
from PIL import Image
from datasets import load_dataset
from huggingface_hub import login

# Inference endpoint and credentials. The same HF_READ_TOKEN environment
# variable is used for the Hub login, the inference requests and the
# authenticated dataset download below.
API_TOKEN = os.getenv("HF_READ_TOKEN") # it is free
API_URL = "https://api-inference.huggingface.co/models/openskyml/open-diffusion-v1"
headers = {"Authorization": f"Bearer {API_TOKEN}"}

# Authenticate with the Hub before fetching the word list, which is loaded
# with use_auth_token=True.
login(token=API_TOKEN)

# Blocked-word list used by query() to screen prompts.
word_list_dataset = load_dataset(
    "openskyml/bad-words-prompt-list",
    data_files="en.txt",
    use_auth_token=True,
)
word_list = word_list_dataset["train"]["text"]

def query(prompt, is_negative=False, steps=5, cfg_scale=7, seed=None, num_images=4):
    """Generate ``num_images`` images for ``prompt`` via the inference endpoint.

    The prompt is first screened against the module-level ``word_list``; each
    entry is matched as a whole word, case-insensitively, with regex
    metacharacters escaped. One POST request is then sent per image.

    Args:
        prompt: Text prompt. ``", 8k"`` is appended before it is sent.
        is_negative: Forwarded unchanged as the ``is_negative`` payload field.
        steps: Forwarded unchanged as the ``steps`` payload field.
        cfg_scale: Forwarded unchanged as the ``cfg_scale`` payload field.
        seed: Seed sent with every request. When ``None``, a new random value
            is drawn for each image.
        num_images: Number of requests to send, i.e. images to return.

    Returns:
        A list of PIL images, one per request.

    Raises:
        gr.Error: If the prompt contains a blocked word, the request cannot be
            completed, or the endpoint answers with a non-success status.
    """
    for blocked_word in word_list:
        term = (blocked_word or "").strip()
        # A blank entry would produce the pattern r"\b\b", which matches
        # nearly every prompt and would block all requests.
        if not term:
            continue
        # Escape the entry so it is matched literally; ignore case so that
        # capitalisation cannot be used to slip past the filter.
        if re.search(rf"\b{re.escape(term)}\b", prompt, flags=re.IGNORECASE):
            raise gr.Error("Unsafe content found. Please try again with different prompts.")

    # Generous upper bound (seconds) so a stalled connection cannot hang the
    # worker forever; requests has no timeout by default.
    request_timeout = 120

    images = []
    for _ in range(num_images):
        payload = {
            "inputs": prompt + ", 8k",
            "is_negative": is_negative,
            "steps": steps,
            "cfg_scale": cfg_scale,
            "seed": seed if seed is not None else random.randint(-1, 2147483647),
        }

        try:
            response = requests.post(
                API_URL, headers=headers, json=payload, timeout=request_timeout
            )
        except requests.RequestException as err:
            raise gr.Error(
                "Could not reach the image generation service. Please try again later."
            ) from err

        # On failure the body is an error document, not image bytes; report it
        # instead of letting Image.open fail with an unrelated decoding error.
        if not response.ok:
            raise gr.Error(
                f"Image generation failed (HTTP {response.status_code}). Please try again later."
            )

        images.append(Image.open(io.BytesIO(response.content)))

    return images


# Custom stylesheet handed to gr.Blocks(css=css) below.
# "#gallery" matches the elem_id given to the output gallery.
# NOTE(review): "#prompt-text-input", "#negative-prompt-text-input" and
# ".image_duplication" are not assigned to any component in this file, and
# "#component-16" looks like an auto-generated id -- confirm these rules
# still apply to anything.
css = """
        .gradio-container {
            font-family: 'IBM Plex Sans', sans-serif;
        }
        #gallery {
            min-height: 22rem;
            margin-bottom: 15px;
            margin-left: auto;
            margin-right: auto;
            border-bottom-right-radius: .5rem !important;
            border-bottom-left-radius: .5rem !important;
        }
        #gallery>div>.h-full {
            min-height: 20rem;
        
        }
        
        #prompt-text-input, #negative-prompt-text-input{padding: .45rem 0.625rem}
        #component-16{border-top-width: 1px!important;margin-top: 1em}
        .image_duplication{position: absolute; width: 100px; left: 50px}
"""

# Build the UI: a title banner, the output gallery, and a box holding the
# prompt, the negative prompt and the generate button.
with gr.Blocks(css=css) as demo:
    gr.HTML(
        """
            <div style="text-align: center; margin: 0 auto;">
              <div
                style="
                  display: inline-flex;
                  align-items: center;
                  gap: 0.8rem;
                  font-size: 1.75rem;
                "
              >
                <h1 style="font-weight: 900; margin-bottom: 7px;margin-top:5px">
                  Open Diffusion 1.0 Demo
                </h1>
              </div>
            </div>
        """
    )

    with gr.Row():
        # The column count is passed to the constructor; the former
        # .style(grid=[2]) call is the deprecated way to set the same value.
        gallery_output = gr.Gallery(
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            columns=[2],
        )

    with gr.Row():
        with gr.Box():
            text_prompt = gr.Textbox(show_label=False, placeholder="Enter your prompt", max_lines=1)
            negative_prompt = gr.Textbox(show_label=False, placeholder="Enter a negative", max_lines=1)
            text_button = gr.Button("Generate", icon="https://www.gstatic.com/android/keyboard/emojikitchen/20210521/u1fa84/u1fa84_u1fa84.png")

    # The two textboxes map positionally onto query(prompt, is_negative); the
    # returned list of images fills the gallery.
    text_button.click(query, inputs=[text_prompt, negative_prompt], outputs=gallery_output)

# Listen on all interfaces at port 7860, with the API page disabled.
demo.launch(show_api=False, server_name="0.0.0.0", server_port=7860)