File size: 8,598 Bytes
f9aa991
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f66e410
f9aa991
 
 
 
 
 
 
 
 
 
 
 
 
47bc811
f66e410
e8d1004
f7cf3c7
47bc811
f9aa991
846b367
 
 
 
 
 
 
 
 
 
9c46709
 
 
 
 
 
f7cf3c7
9c46709
 
 
 
 
 
 
 
 
 
 
 
 
 
846b367
f9aa991
 
846b367
05693ff
 
 
a8ac5c3
05693ff
 
 
846b367
05693ff
d3260f5
f9aa991
 
 
 
 
 
 
846b367
e73313a
f9aa991
 
 
 
 
 
 
 
 
 
846b367
2e0a09e
f66e410
e73313a
f66e410
 
 
 
ed61e3f
f66e410
f9aa991
 
 
 
 
 
6557239
907f5c4
14b40a7
f66e410
ba366b3
 
e73313a
f66e410
 
6557239
ba366b3
c109868
f66e410
ba366b3
907f5c4
 
f66e410
 
 
 
 
 
 
 
f9aa991
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f66e410
f9aa991
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import spaces
import argparse
import os
import time
from os import path
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

# Redirect every Hugging Face cache (transformers, hub, and the general HF
# home) into a local "models" directory next to this script, so downloaded
# weights persist alongside the app instead of in the user-level cache.
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
for _cache_env in ("TRANSFORMERS_CACHE", "HF_HUB_CACHE", "HF_HOME"):
    os.environ[_cache_env] = cache_path

import gradio as gr
import torch
from diffusers import FluxPipeline

# Allow TF32 matmuls: trades a small amount of precision for a large
# speedup on Ampere-and-newer GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
# Name of the accelerator LoRA adapter currently active on the pipeline;
# None until process_image() activates one on first use.
loaded_acc = None
class timer:
    """Context manager that prints the wall-clock duration of a block.

    Prints "<label> starts" on entry and "<label> took N.NNs" on exit.
    Used here to log how long pipeline inference takes.
    """

    def __init__(self, method_name="timed process"):
        # Human-readable label used in both log lines.
        self.method = method_name

    def __enter__(self):
        self.start = time.time()
        print(f"{self.method} starts")
        # Fix: return self so `with timer(...) as t:` binds the timer
        # instead of None (the original returned nothing). Existing call
        # sites that ignore the return value are unaffected.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.time()
        print(f"{self.method} took {str(round(end - self.start, 2))}s")

# Ensure the local model cache directory exists. exist_ok=True makes the
# prior `path.exists` guard redundant (and removes its create-between-check
# race), so the call stands alone.
os.makedirs(cache_path, exist_ok=True)

# Load the base FLUX.1-dev pipeline in bfloat16 and attach both TDD
# accelerator LoRAs (standard and adversarial) by adapter name.
# NOTE(review): neither adapter is activated here — process_image() calls
# set_adapters() lazily on first generation.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights(hf_hub_download("RED-AIGC/TDD", "FLUX.1-dev_tdd_lora_weights.safetensors"),adapter_name="TDD")
pipe.load_lora_weights(hf_hub_download("RED-AIGC/TDD", "FLUX.1-dev_tdd_adv_lora_weights.safetensors"),adapter_name="TDD_adv")
# pipe.fuse_lora(lora_scale=0.125)
pipe.to("cuda")

# Inline CSS for the Gradio app: center the page title and cap the
# container width so the layout stays readable on wide screens.
css = """
h1 {
    text-align: center;
    display:block;
}
.gradio-container {
  max-width: 70.5rem !important;
}
"""

@spaces.GPU
def process_image(prompt, acc, height, width, steps, scales, seed):
    """Generate one image with the shared FLUX pipeline.

    Lazily switches the active accelerator LoRA when the requested adapter
    differs from the one last used, then runs inference under bfloat16
    autocast and returns the first generated image.
    """
    global pipe
    global loaded_acc
    # Swap adapters only on change — set_adapters is not free per call.
    if acc != loaded_acc:
        pipe.set_adapters([acc], adapter_weights=[0.125])
        print(pipe.get_active_adapters())
        loaded_acc = acc
    rng = torch.Generator().manual_seed(int(seed))
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
        result = pipe(
            prompt=[prompt],
            generator=rng,
            num_inference_steps=int(steps),
            guidance_scale=float(scales),
            height=int(height),
            width=int(width),
            max_sequence_length=256,
        )
        return result.images[0]


# Build the Gradio UI: header markdown, prompt + advanced controls on the
# left, output image on the right, curated examples, usage notes, and the
# generate-button wiring.
with gr.Blocks(css=css) as demo:
    # Page header / project links (rendered as markdown).
    gr.Markdown(
        """
        # FLUX.1-dev(beta) distilled by ✨Target-Driven Distillation✨
        
        Compared to Hyper-FLUX, the beta version of TDD has its parameters reduced by half(600MB), resulting in more realistic details. 
        
        Due to limitations in machine resources, there are still many imperfections in the beta version. 
        
        Besides, TDD is also available for distilling video generation models. This space presents TDD-distilled [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev).
        
        [**Project Page**](https://redaigc.github.io/TDD/) **|** [**Paper**](https://arxiv.org/abs/2409.01347) **|** [**Code**](https://github.com/RedAIGC/Target-Driven-Distillation) **|** [**Model**](https://huggingface.co/RED-AIGC/TDD) **|** [🤗 **TDD-SDXL Demo**](https://huggingface.co/spaces/RED-AIGC/TDD) **|** [🤗 **TDD-SVD Demo**](https://huggingface.co/spaces/RED-AIGC/SVD-TDD)
        
        The codes of this space are built on [Hyper-FLUX](https://huggingface.co/spaces/ByteDance/Hyper-FLUX-8Steps-LoRA) and we acknowledge their contribution.
        """
    )

    with gr.Row():
        # Left column: prompt box plus collapsed advanced settings.
        with gr.Column(scale=3):
            with gr.Group():
                prompt = gr.Textbox(
                    label="Prompt",
                    # value="portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
                    lines=3
                )
                
                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Group():
                        # Resolution sliders; step of 64 keeps sizes
                        # divisible as the model expects.
                        with gr.Row():
                            height = gr.Slider(label="Height", minimum=256, maximum=1152, step=64, value=1024)
                            width = gr.Slider(label="Width", minimum=256, maximum=1152, step=64, value=1024)
                        
                        # Few-step distilled sampling (4–10 steps) and CFG scale.
                        with gr.Row():
                            steps = gr.Slider(label="Inference Steps", minimum=4, maximum=10, step=1, value=8)
                            scales = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=5.0, step=0.1, value=2.0)
                        # NOTE(review): seed -1 is passed straight to
                        # manual_seed — there is no "random seed" handling.
                        with gr.Row():
                            seed = gr.Number(label="Seed", value=-1, precision=0)
                        # Which accelerator LoRA to activate; names must match
                        # the adapter_name values registered on the pipeline.
                        with gr.Row():
                            acc = gr.Dropdown(
                                label="Accelerate Lora",
                                choices=["TDD", "TDD_adv"],
                                value="TDD_adv",
                            )
                
                generate_btn = gr.Button("Generate Image", variant="primary", scale=1)

        # Right column: generated image output.
        with gr.Column(scale=4):
            output = gr.Image(label="Your Generated Image")
    
    # Example prompts used by the gr.Examples gallery below.
    person1="A young woman with strikingly symmetrical features, smooth porcelain skin, and large hazel eyes framed by thick, wavy auburn hair, sitting beside a window where soft sunlight highlights her high cheekbones and glossy lips."
    person2="A radiant woman with soft, youthful skin, her lips lightly tinted in a natural pink, and her large, round eyes framed by thick lashes, posing gracefully in a traditional kimono among red maple leaves."
    person3="A modern woman with striking blue eyes, a perfectly symmetrical face, and soft, glossy lips, sitting in a sleek interior with soft, natural light highlighting her flawless skin."
    dog="Portrait photo of a Shiba Inu, photograph, highly detailed fur, warm and cheerful light, soft pastel tones, vibrant and sunny atmosphere, style by Tim Flach, bright daylight, natural setting, centered, extremely detailed, Nikon D850, award-winning photography"
    fox="A majestic Arctic fox standing gracefully on a snowy tundra, its thick white fur blending seamlessly with the icy surroundings. The texture of the fox's fur is highly detailed, showcasing its fluffy and warm coat. "
    scenery1="A quiet autumn street lined with maple trees, their vibrant red and orange leaves falling gently to the ground. The pavement is partially covered with scattered leaves, while a soft breeze stirs the branches. Sunlight filters through the canopy, casting warm, dappled light on the sidewalk. A lone cyclist rides down the street, and a few pedestrians stroll along, enjoying the crisp autumn air. The scene captures the serene beauty and subtle movement of an autumn day."
    scenery2="A high-detail close-up of the Forbidden City's architectural details, with snowflakes gracefully falling. Shot on a Nikon D850, the 45.7-megapixel sensor captures the scene's bright warmth. A fast shutter speed at a low ISO freezes the snowflakes, ensuring a crisp, vivid image. Ideal for a winter magazine feature."
    # Clickable examples; each row matches the inputs list order
    # (prompt, adapter, height, width, steps, guidance, seed).
    gr.Examples(
        examples=[
            [person1, "TDD_adv", 1024, 1024, 8, 2.1, 5685],
            [person2, "TDD_adv", 1024, 1024, 8, 2, 8888],
            [person3,"TDD_adv", 1024, 1024, 6, 1.8, 3420],
            [dog, "TDD", 1024, 1024, 6, 2, 29],
            [fox,"TDD_adv", 1024, 1024, 8, 2.1, 4678],
            [scenery1, "TDD_adv", 1024, 1024, 8, 2.5, 9669],
            [scenery2, "TDD_adv", 1024, 1024, 6, 1.8, 3420],
        ],
        # inputs=[prompt, negative_prompt, ckpt, acc, steps, guidance_scale, eta, seed],
        inputs=[prompt,acc, height, width, steps, scales, seed],
        outputs=output,
        fn=process_image,
        # "lazy" defers example rendering until first request.
        cache_examples="lazy",
    )

    # Static usage instructions.
    gr.Markdown(
        """
        <div style="max-width: 650px; margin: 2rem auto; padding: 1rem; border-radius: 10px; background-color: #f0f0f0;">
            <h2 style="font-size: 1.5rem; margin-bottom: 1rem;">How to Use</h2>
            <ol style="padding-left: 1.5rem;">
                <li>Enter a detailed description of the image you want to create.</li>
                <li>Adjust advanced settings if desired (tap to expand).</li>
                <li>Tap "Generate Image" and wait for your creation!</li>
            </ol>
            <p style="margin-top: 1rem; font-style: italic;">Tip: Be specific in your description for best results!</p>
        </div>
        """
    )

    # Wire the button to the generation function.
    generate_btn.click(
        process_image,
        inputs=[prompt, acc,height, width, steps, scales, seed],
        outputs=output
    )

if __name__ == "__main__":
    # Start the Gradio server when the script is run directly.
    demo.launch()