File size: 3,683 Bytes
94cbfd9
 
8ccf632
 
 
 
a9f04e0
06f0278
94cbfd9
 
06f0278
 
8ccf632
d233b7e
bc0adb1
8ccf632
06f0278
94cbfd9
 
 
 
8ccf632
0e94382
94cbfd9
54192f0
 
94cbfd9
8ccf632
94cbfd9
8ccf632
1e787e4
 
 
 
 
94cbfd9
 
 
8ccf632
 
94cbfd9
8ccf632
94cbfd9
8ccf632
 
 
 
 
94cbfd9
 
 
8ccf632
 
94cbfd9
 
 
8ccf632
 
 
 
 
 
 
 
94cbfd9
8ccf632
94cbfd9
8ccf632
 
 
 
 
 
 
94cbfd9
8ccf632
 
 
 
 
 
94cbfd9
8ccf632
 
b213a9c
 
 
 
ceb48e8
94cbfd9
8ccf632
 
94cbfd9
8ccf632
 
 
94cbfd9
8ccf632
 
94cbfd9
 
 
 
8ccf632
 
 
 
2b62414
8ccf632
b213a9c
8ccf632
 
 
9aa8809
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import random

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import  DiffusionPipeline, FlowMatchEulerDiscreteScheduler
from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
from gradio_imagefeed import ImageFeed


# Model weights are loaded in bfloat16; the pipeline is then moved to CUDA
# when torch reports it as available, and to the CPU otherwise.
dtype = torch.bfloat16
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=dtype,
).to(device)

# Upper bound for the seed slider and for randomly drawn seeds
# (largest signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound for the width and height sliders.
MAX_IMAGE_SIZE = 2048
# Markdown heading rendered at the top of the page.
LICENSE = (
    "# Better UI for FLUX.1 [dev] "
    "[[non-commercial license]"
    "(https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)]"
)
# Centres the main column and caps its width.
CSS = "#col-container { margin: 0 auto; max-width: 900px; }"
# Prompts offered through the gr.Examples widget.
EXAMPLES = [
    "a tiny elephant hatching from a turtle egg in the palm of a human hand, "
    "highly detailed textures, close-up",
]


@spaces.GPU(duration=45)
def infer(
    prompt,
    seed=99999,
    randomize_seed=True,
    width=896,
    height=1152,
    guidance_scale=5.0,
    num_inference_steps=28,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image for ``prompt`` with the module-level ``pipe``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        seed: Seed for the torch generator. Replaced by a random value
            when ``randomize_seed`` is true.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Requested image width, forwarded to the pipeline.
        height: Requested image height, forwarded to the pipeline.
        guidance_scale: Guidance scale, forwarded to the pipeline unchanged.
        num_inference_steps: Number of inference steps, forwarded to the
            pipeline.
        progress: Gradio progress tracker created with ``track_tqdm=True``;
            it is not referenced in the body.

    Yields:
        A single ``(image, seed)`` pair: the first image of the pipeline
        output and the integer seed that was actually used.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Gradio documents slider values as floats, while Generator.manual_seed
    # rejects non-integers and the size/step arguments are integer
    # quantities, so normalise them before use.
    seed = int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    generator = torch.Generator().manual_seed(seed)

    image = pipe(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        generator=generator,
        guidance_scale=guidance_scale,
    ).images[0]

    yield image, seed


# Page layout: heading, prompt row, result feed, collapsible advanced
# settings and an examples widget, all inside one centred column.
with gr.Blocks(css=CSS) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(LICENSE)

        # Prompt box with the run button beside it.
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                placeholder="Prompt",
                max_lines=5,
                show_label=False,
                container=False,
            )
            run_button = gr.Button("Run", scale=0)

        # Output goes to an ImageFeed component (used here in place of a
        # plain gr.Image).
        result = ImageFeed(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            # The slider starts on a random seed each time the app is built.
            seed = gr.Slider(
                label="Seed",
                value=random.randint(0, MAX_SEED),
                minimum=0,
                maximum=MAX_SEED,
                step=1,
            )
            randomize_seed = gr.Checkbox(label="Randomize", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    value=896,
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                )
                height = gr.Slider(
                    label="Height",
                    value=1152,
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    value=3,
                    minimum=1,
                    maximum=15,
                    step=0.1,
                )
                num_inference_steps = gr.Slider(
                    label="Inference Steps",
                    value=28,
                    minimum=1,
                    maximum=50,
                    step=1,
                )

        # Examples pass only the prompt to infer, so every other argument
        # takes infer's own default value.
        gr.Examples(
            examples=EXAMPLES,
            fn=infer,
            inputs=[prompt],
            outputs=[result, seed],
            cache_examples="lazy",
        )

    # Clicking "Run" and submitting the prompt box both trigger generation;
    # the seed that was used is written back to the seed slider.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )

demo.launch()