File size: 6,014 Bytes
5252609
827c867
5252609
925b9c1
b4808d1
2330a16
b3925d0
dae60c0
 
5252609
5abbc9d
 
ea450aa
 
 
 
 
 
 
925b9c1
b3925d0
5252609
b3925d0
 
72f6034
5252609
 
ea450aa
dae60c0
 
 
 
84fe6c1
 
 
 
 
5252609
84fe6c1
 
dae60c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359047c
 
 
dae60c0
 
 
988a34b
827c867
 
 
 
 
 
dae60c0
828375c
988a34b
84fe6c1
 
827c867
dae60c0
828375c
84fe6c1
5252609
84fe6c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ca9ce3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import torch
from diffusers import StableDiffusion3Pipeline
import gradio as gr
import os
import random
import transformers
import numpy as np
from transformers import T5Tokenizer, T5ForConditionalGeneration
import spaces

# Optional Hugging Face access token from the environment; forwarded to the
# SD3 download below (None means no explicit token is passed).
HF_TOKEN = os.getenv("HF_TOKEN")

# Pick the compute device and a matching dtype for the diffusion pipeline.
if torch.cuda.is_available():
    device = "cuda"
    # Half precision on the GPU (this is what the app always used before).
    pipe_dtype = torch.float16
    print("Using GPU")
else:
    device = "cpu"
    # float16 kernels are slow or missing on many CPUs, so use full precision
    # when no GPU is available.
    pipe_dtype = torch.float32
    print("Using CPU")


# Largest value of a signed 32-bit int; upper bound of the UI seed slider.
MAX_SEED = np.iinfo(np.int32).max


# Initialize the pipeline and download the sd3 medium model
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=pipe_dtype,
    token=HF_TOKEN,
)
pipe.to(device)

# superprompt-v1: T5 model used by generate_image to expand short prompts.
tokenizer = T5Tokenizer.from_pretrained("roborovski/superprompt-v1")
model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", device_map="auto", torch_dtype="auto")
# NOTE(review): device_map="auto" already places the model; this .to() looks
# redundant — confirm (e.g. under ZeroGPU) before removing either.
model.to(device)

# Toggle visibility of the enhanced-prompt output box.
def update_visibility(enhance_prompt):
    """Build a Gradio update that shows or hides a component.

    Args:
        enhance_prompt: Current value of the "Prompt Enhancement" checkbox.

    Returns:
        A ``gr.update`` whose ``visible`` flag equals ``enhance_prompt``.
    """
    visibility_update = gr.update(visible=enhance_prompt)
    return visibility_update


# Define the image generation function
@spaces.GPU(duration=80)
def generate_image(prompt, enhance_prompt, negative_prompt, num_inference_steps, height, width, guidance_scale, seed, num_images_per_prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate images with SD3 Medium, optionally expanding the prompt first.

    Args:
        prompt: User text prompt.
        enhance_prompt: If truthy, the prompt is first expanded by the
            SuperPrompt-v1 T5 model and the expanded text is used instead.
        negative_prompt: Text describing what should not appear in the image.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        height: Output image height in pixels.
        width: Output image width in pixels.
        guidance_scale: How closely generation follows the text prompt.
        seed: RNG seed; 0 means "draw a random seed".
        num_images_per_prompt: Number of images to generate.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors tqdm
            progress bars into the UI.

    Returns:
        Tuple ``(images, prompt)``: the ``.images`` of the pipeline output and
        the prompt actually used (the enhanced one if enhancement ran).
    """
    # Slider values may arrive as floats (e.g. 25.0); manual_seed and the
    # pipeline's step/size/count arguments need ints.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    height = int(height)
    width = int(width)
    num_images_per_prompt = int(num_images_per_prompt)

    if seed == 0:
        # Draw within the UI seed slider's range so the seed is one a user
        # could enter again.
        seed = random.randint(1, MAX_SEED)

    if enhance_prompt:
        # Seed the RNGs so the sampled prompt expansion is reproducible.
        transformers.set_seed(seed)

        input_text = f"Expand the following prompt to add more detail: {prompt}"
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

        outputs = model.generate(
            input_ids,
            max_new_tokens=512,
            repetition_penalty=1.2,
            do_sample=True,
            temperature=0.7,
            top_p=1,
            top_k=50,
        )
        prompt = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # NOTE: a CPU generator is used; presumably for device-independent
    # reproducibility of the initial latents — confirm if changing.
    generator = torch.Generator().manual_seed(seed)

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        generator=generator,
        num_images_per_prompt=num_images_per_prompt
    ).images

    return output, prompt



# Create the Gradio interface

# Rows for gr.Examples below; each row is [prompt, enhance_prompt].
examples = [
    ["A white car racing fast to the moon.", True],
    ["A woman in a red dress singing on top of a building.", True],
    ["An astronaut on mars in a futuristic cyborg suit.", True],
]

# Page styling: cap the app width and center <h1> headings.
css = '''
.gradio-container{max-width: 1000px !important}
h1{text-align:center}
'''
with gr.Blocks(css=css) as demo:
    # Header: title and credits.
    with gr.Row():
        with gr.Column():
            gr.HTML(
            """
            <h1 style='text-align: center'>
            Stable Diffusion 3 Medium Superprompt
            </h1>
            """
            )
            gr.HTML(
                """
               Made by <a href='https://linktr.ee/Nick088' target='_blank'>Nick088</a>
               <br> <a href="https://discord.gg/osai"> <img src="https://img.shields.io/discord/1198701940511617164?color=%23738ADB&label=Discord&style=for-the-badge" alt="Discord"> </a>
                """
            )
    # Main prompt controls and outputs.
    with gr.Group():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", info="Describe the image you want", placeholder="A cat...")
            enhance_prompt = gr.Checkbox(label="Prompt Enhancement with SuperPrompt-v1", value=True)
            run_button = gr.Button("Run")
        result = gr.Gallery(label="Generated AI Images", elem_id="gallery")
        better_prompt = gr.Textbox(label="Enhanced Prompt", info="The output of your enhanced prompt used for the Image Generation", visible=True)
        # Show the enhanced-prompt box only while enhancement is enabled.
        enhance_prompt.change(fn=update_visibility, inputs=enhance_prompt, outputs=better_prompt)
    with gr.Accordion("Advanced options", open=False):
        with gr.Row():
            negative_prompt = gr.Textbox(label="Negative Prompt", info="Describe what you don't want in the image", value="deformed, distorted, disfigured, poorly drawn, bad anatomy, incorrect anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation", placeholder="Ugly, bad anatomy...")
        with gr.Row():
            num_inference_steps = gr.Slider(label="Number of Inference Steps", info="The number of denoising steps of the image. More denoising steps usually lead to a higher quality image at the cost of slower inference", minimum=1, maximum=50, value=25, step=1)
            guidance_scale = gr.Slider(label="Guidance Scale", info="Controls how much the image generation process follows the text prompt. Higher values make the image stick more closely to the input text.", minimum=0.0, maximum=10.0, value=7.5, step=0.1)
        with gr.Row():
            width = gr.Slider(label="Width", info="Width of the Image", minimum=256, maximum=1344, step=32, value=1024)
            height = gr.Slider(label="Height", info="Height of the Image", minimum=256, maximum=1344, step=32, value=1024)
        with gr.Row():
            seed = gr.Slider(value=42, minimum=0, maximum=MAX_SEED, step=1, label="Seed", info="A starting point to initiate the generation process, put 0 for a random one")
            num_images_per_prompt = gr.Slider(label="Images Per Prompt", info="Number of Images to generate with the settings",minimum=1, maximum=4, step=1, value=2)

    gr.Examples(
        examples=examples,
        inputs=[prompt, enhance_prompt],
        outputs=[result, better_prompt],
        fn=generate_image,
        # The examples supply only 2 of generate_image's 9 required inputs;
        # caching would call it with just these values and fail, so keep
        # caching explicitly off.
        cache_examples=False,
    )

    gr.on(
        triggers=[
            prompt.submit,
            run_button.click,
        ],
        fn=generate_image,
        # Order must match generate_image's positional parameters
        # (..., num_inference_steps, height, width, guidance_scale, ...).
        inputs=[
            prompt,
            enhance_prompt,
            negative_prompt,
            num_inference_steps,
            height,
            width,
            guidance_scale,
            seed,
            num_images_per_prompt,
        ],
        outputs=[result, better_prompt],
    )

demo.queue().launch(share=False)