Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,014 Bytes
5252609 827c867 5252609 925b9c1 b4808d1 2330a16 b3925d0 dae60c0 5252609 5abbc9d ea450aa 925b9c1 b3925d0 5252609 b3925d0 72f6034 5252609 ea450aa dae60c0 84fe6c1 5252609 84fe6c1 dae60c0 359047c dae60c0 988a34b 827c867 dae60c0 828375c 988a34b 84fe6c1 827c867 dae60c0 828375c 84fe6c1 5252609 84fe6c1 4ca9ce3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import torch
from diffusers import StableDiffusion3Pipeline
import gradio as gr
import os
import random
import transformers
import numpy as np
from transformers import T5Tokenizer, T5ForConditionalGeneration
import spaces
# Hugging Face access token taken from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")

# Choose the compute device once and report the choice on stdout.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using GPU" if device == "cuda" else "Using CPU")

# Largest value representable as a signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Initialize the pipeline and download the sd3 medium model
# (weights are requested as float16, then the pipeline is moved to `device`).
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16,
)
pipe.to(device)

# superprompt-v1: T5 tokenizer and model used to expand short prompts.
tokenizer = T5Tokenizer.from_pretrained("roborovski/superprompt-v1")
model = T5ForConditionalGeneration.from_pretrained(
    "roborovski/superprompt-v1",
    device_map="auto",
    torch_dtype="auto",
)
model.to(device)
# Toggle the visibility of the enhanced prompt output
def update_visibility(enhance_prompt):
    """Return a Gradio update whose ``visible`` flag equals *enhance_prompt*.

    Args:
        enhance_prompt: Current value of the prompt-enhancement checkbox.

    Returns:
        The object produced by ``gr.update(visible=enhance_prompt)``.
    """
    visibility_change = gr.update(visible=enhance_prompt)
    return visibility_change
# Define the image generation function
# NOTE(review): `spaces.GPU(duration=80)` presumably requests a ZeroGPU slot
# for up to 80 seconds per call — confirm against the `spaces` package docs.
@spaces.GPU(duration=80)
def generate_image(
    prompt,
    enhance_prompt,
    negative_prompt,
    num_inference_steps,
    height,
    width,
    guidance_scale,
    seed,
    num_images_per_prompt,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate images for *prompt*, optionally expanding it with SuperPrompt first.

    Args:
        prompt: Text description of the desired image.
        enhance_prompt: When truthy, the prompt is rewritten by the
            module-level T5 ``model`` before being given to the pipeline.
        negative_prompt: Text describing what the image should not contain;
            forwarded to the pipeline unchanged.
        num_inference_steps: Number of denoising steps (coerced to ``int``).
        height: Output image height (coerced to ``int``).
        width: Output image width (coerced to ``int``).
        guidance_scale: Guidance strength; forwarded to the pipeline unchanged.
        seed: Seed for prompt enhancement and image generation (coerced to
            ``int``). ``0`` means "draw a random seed".
        num_images_per_prompt: Number of images to produce (coerced to ``int``).
        progress: Gradio progress tracker. It is not referenced in the body;
            it is declared with ``track_tqdm=True`` in the signature.

    Returns:
        A ``(images, prompt)`` tuple: the ``.images`` attribute of the
        pipeline output and the prompt that was actually used (the enhanced
        one when *enhance_prompt* is truthy, otherwise the input prompt).
    """
    # UI sliders may deliver numbers as floats; torch.Generator.manual_seed and
    # the pipeline's size/step arguments need real integers.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    height = int(height)
    width = int(width)
    num_images_per_prompt = int(num_images_per_prompt)

    # A seed of 0 is the "pick one for me" sentinel.
    if seed == 0:
        seed = random.randint(1, 2**32 - 1)

    if enhance_prompt:
        # Seed the sampling below so the same seed gives the same enhanced prompt.
        transformers.set_seed(seed)

        input_text = f"Expand the following prompt to add more detail: {prompt}"
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

        outputs = model.generate(
            input_ids,
            max_new_tokens=512,
            repetition_penalty=1.2,
            do_sample=True,
            temperature=0.7,
            top_p=1,
            top_k=50,
        )

        # From here on the enhanced text replaces the user's prompt.
        prompt = tokenizer.decode(outputs[0], skip_special_tokens=True)

    generator = torch.Generator().manual_seed(seed)

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        generator=generator,
        num_images_per_prompt=num_images_per_prompt,
    ).images

    return output, prompt
# Create the Gradio interface
# NOTE(review): the layout nesting below (and the whitespace inside the HTML
# strings) was reconstructed from a copy of this file whose indentation was
# lost — confirm it against the intended page layout.

# Each example row supplies values for [prompt, enhance_prompt].
examples = [
    ["A white car racing fast to the moon.", True],
    ["A woman in a red dress singing on top of a building.", True],
    ["An astronaut on mars in a futuristic cyborg suit.", True],
]

css = '''
.gradio-container{max-width: 1000px !important}
h1{text-align:center}
'''

with gr.Blocks(css=css) as demo:
    # Page header: title and author/credit links.
    with gr.Row():
        with gr.Column():
            gr.HTML(
                """
                <h1 style='text-align: center'>
                Stable Diffusion 3 Medium Superprompt
                </h1>
                """
            )
            gr.HTML(
                """
                Made by <a href='https://linktr.ee/Nick088' target='_blank'>Nick088</a>
                <br> <a href="https://discord.gg/osai"> <img src="https://img.shields.io/discord/1198701940511617164?color=%23738ADB&label=Discord&style=for-the-badge" alt="Discord"> </a>
                """
            )

    with gr.Group():
        # Main controls and outputs.
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                info="Describe the image you want",
                placeholder="A cat...",
            )
            enhance_prompt = gr.Checkbox(
                label="Prompt Enhancement with SuperPrompt-v1",
                value=True,
            )
            run_button = gr.Button("Run")
            result = gr.Gallery(label="Generated AI Images", elem_id="gallery")
            better_prompt = gr.Textbox(
                label="Enhanced Prompt",
                info="The output of your enhanced prompt used for the Image Generation",
                visible=True,
            )
            # Show the enhanced-prompt box only while enhancement is enabled.
            enhance_prompt.change(
                fn=update_visibility,
                inputs=enhance_prompt,
                outputs=better_prompt,
            )

        with gr.Accordion("Advanced options", open=False):
            with gr.Row():
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    info="Describe what you don't want in the image",
                    value="deformed, distorted, disfigured, poorly drawn, bad anatomy, incorrect anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
                    placeholder="Ugly, bad anatomy...",
                )
            with gr.Row():
                num_inference_steps = gr.Slider(
                    label="Number of Inference Steps",
                    info="The number of denoising steps of the image. More denoising steps usually lead to a higher quality image at the cost of slower inference",
                    minimum=1,
                    maximum=50,
                    value=25,
                    step=1,
                )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    info="Controls how much the image generation process follows the text prompt. Higher values make the image stick more closely to the input text.",
                    minimum=0.0,
                    maximum=10.0,
                    value=7.5,
                    step=0.1,
                )
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    info="Width of the Image",
                    minimum=256,
                    maximum=1344,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    info="Height of the Image",
                    minimum=256,
                    maximum=1344,
                    step=32,
                    value=1024,
                )
            with gr.Row():
                seed = gr.Slider(
                    value=42,
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    label="Seed",
                    info="A starting point to initiate the generation process, put 0 for a random one",
                )
                num_images_per_prompt = gr.Slider(
                    label="Images Per Prompt",
                    info="Number of Images to generate with the settings",
                    minimum=1,
                    maximum=4,
                    step=1,
                    value=2,
                )

    # The examples supply only two of generate_image's nine required inputs,
    # so the function cannot be run from them; caching is disabled explicitly
    # and clicking an example just fills in the prompt and the checkbox.
    gr.Examples(
        examples=examples,
        inputs=[prompt, enhance_prompt],
        outputs=[result, better_prompt],
        fn=generate_image,
        cache_examples=False,
    )

    # Run generation on Enter in the prompt box or on the Run button.
    gr.on(
        triggers=[
            prompt.submit,
            run_button.click,
        ],
        fn=generate_image,
        # Order must match generate_image's positional parameters:
        # height comes before width in its signature.
        inputs=[
            prompt,
            enhance_prompt,
            negative_prompt,
            num_inference_steps,
            height,
            width,
            guidance_scale,
            seed,
            num_images_per_prompt,
        ],
        outputs=[result, better_prompt],
    )

demo.queue().launch(share=False)