File size: 3,905 Bytes
5252609
827c867
5252609
925b9c1
2330a16
dae60c0
5abbc9d
dae60c0
5252609
5abbc9d
 
ea450aa
 
 
 
 
 
 
 
5abbc9d
 
 
 
 
 
 
 
925b9c1
5252609
 
5abbc9d
5252609
 
ea450aa
dae60c0
 
 
 
5252609
 
ea450aa
dae60c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
988a34b
827c867
 
 
 
 
 
dae60c0
828375c
988a34b
 
5252609
 
827c867
 
 
dae60c0
 
827c867
 
 
 
828375c
827c867
828375c
827c867
 
 
dae60c0
 
828375c
 
5252609
 
dae60c0
54f830c
5252609
988a34b
5252609
 
 
e8e6d43
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import random

import gradio as gr
import spaces
import torch
import transformers
from diffusers import StableDiffusion3Pipeline
from huggingface_hub import snapshot_download
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Access token for the model repository, read from the environment
# (None when the HF_TOKEN variable is not set).
HF_TOKEN = os.getenv("HF_TOKEN")

# Pick the device that the pipeline and the prompt model are moved to.
if torch.cuda.is_available():
    device = "cuda"
    print("Using GPU")
else:
    device = "cpu"
    print("Using CPU")

# download sd3 medium weights
model_path = snapshot_download(
    repo_id="stabilityai/stable-diffusion-3-medium",
    revision="refs/pr/26",
    repo_type="model",
    # Skip markdown files and the git attributes file. The second pattern
    # was previously "*..gitattributes" (doubled dot), which can never match
    # a file named ".gitattributes".
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="stable-diffusion-3-medium",
    token=HF_TOKEN,
)


# Initialize the pipeline from the downloaded snapshot (half precision)
pipe = StableDiffusion3Pipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to(device)

# superprompt-v1: T5 model that generate_image uses to expand the user prompt
tokenizer = T5Tokenizer.from_pretrained("roborovski/superprompt-v1")
model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", device_map="auto", torch_dtype="auto")
# NOTE(review): the model is already placed by device_map="auto"; this extra
# move is kept from the original code -- confirm it is still needed.
model.to(device)

# Define the image generation function
@spaces.GPU(duration=60)
def generate_image(prompt, enhance_prompt, negative_prompt, num_inference_steps, height, width, guidance_scale, seed, num_images_per_prompt):
    """Generate images for *prompt* with the Stable Diffusion 3 pipeline.

    Args:
        prompt: Text describing the desired image.
        enhance_prompt: When truthy, the prompt is first expanded by the
            SuperPrompt T5 model and the expanded text is what gets rendered.
        negative_prompt: Text describing what should not appear.
        num_inference_steps: Forwarded to the pipeline unchanged.
        height: Forwarded to the pipeline unchanged.
        width: Forwarded to the pipeline unchanged.
        guidance_scale: Forwarded to the pipeline unchanged.
        seed: Seed for prompt expansion and image generation; ``0`` means
            "pick a random seed".
        num_images_per_prompt: Forwarded to the pipeline unchanged.

    Returns:
        The ``images`` attribute of the pipeline output (presumably a list of
        PIL images -- confirm against the diffusers version in use).
    """
    # torch.Generator.manual_seed only accepts an int; coerce defensively in
    # case the UI delivers the slider value as a float.
    seed = int(seed)
    if seed == 0:
        seed = random.randint(1, 2**32 - 1)

    if enhance_prompt:
        # Seed the sampling below so the same seed gives the same expansion.
        transformers.set_seed(seed)

        input_text = f"Expand the following prompt to add more detail: {prompt}"
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

        outputs = model.generate(
            input_ids,
            max_new_tokens=512,
            repetition_penalty=1.2,
            do_sample=True,
            temperature=0.7,
            top_p=1,
            top_k=50,
        )

        # Previously the generated tokens were discarded, so enabling the
        # enhancement had no effect on the image. Decode them and use the
        # result; keep the user's prompt if the model produced nothing.
        enhanced = tokenizer.decode(outputs[0], skip_special_tokens=True)
        prompt = enhanced or prompt

    generator = torch.Generator().manual_seed(seed)

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        generator=generator,
        num_images_per_prompt=num_images_per_prompt,
    ).images
    return output

# Create the Gradio interface
# The components are listed in `inputs` below in the same order as the
# parameters of generate_image.

prompt = gr.Textbox(label="Prompt", info="Describe the image you want", placeholder="A cat...")

enhance_prompt = gr.Checkbox(label="Prompt Enhancement", info="Enhance your prompt with SuperPrompt-v1", value=True)

negative_prompt = gr.Textbox(label="Negative Prompt", info="Describe what you don't want in the image", placeholder="Ugly, bad anatomy...")

# This description used to be attached to the guidance scale by mistake.
num_inference_steps = gr.Number(label="Number of Inference Steps", precision=0, value=25, info="The number of denoising steps of the image. More denoising steps usually lead to a higher quality image at the cost of slower inference")

# Slider bounds must be numbers; the maximum was previously the string "1536".
height = gr.Slider(label="Height", info="Height of the Image", minimum=256, maximum=1536, step=32, value=1024)

width = gr.Slider(label="Width", info="Width of the Image", minimum=256, maximum=1536, step=32, value=1024)

guidance_scale = gr.Number(minimum=0.1, value=7.5, label="Guidance Scale", info="How closely the image should follow the prompt. Higher values stick closer to the prompt, usually at the cost of lower image quality")

seed = gr.Slider(value=42, minimum=0, maximum=2**32-1, step=1, label="Seed", info="A starting point to initiate the generation process, put 0 for a random one")

num_images_per_prompt = gr.Slider(label="Number of Images to generate with the settings", minimum=1, maximum=4, step=1, value=1)

interface = gr.Interface(
    fn=generate_image,
    inputs=[prompt, enhance_prompt, negative_prompt, num_inference_steps, height, width, guidance_scale, seed, num_images_per_prompt],
    outputs=gr.Gallery(label="Generated AI Images", elem_id="gallery", show_label=False),
    title="Stable Diffusion 3 Medium",
    description="Made by <a href='https://linktr.ee/Nick088' target='_blank'>Nick088</a> \n Join https://discord.gg/osai to talk about Open Source AI"
)

# Launch the interface
interface.launch(share=False)