File size: 2,631 Bytes
c83e1ca
986d247
f0e2bd1
 
9234004
 
 
 
 
 
 
 
f0e2bd1
 
9234004
1042ff4
 
 
 
 
f0e2bd1
 
9234004
f0e2bd1
 
 
 
 
 
 
1042ff4
 
f0e2bd1
 
 
 
 
 
 
9234004
f0e2bd1
 
9234004
f0e2bd1
1042ff4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c83e1ca
9234004
c83e1ca
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import gradio as gr
import torch
from diffusers import DiffusionPipeline
import random
from huggingface_hub import login
import os

# Log in to the Hugging Face Hub with the token taken from the HF_TOKEN
# environment variable (Space secret). A missing token aborts start-up.
hf_token = os.environ.get('HF_TOKEN')
if hf_token is None:
    raise ValueError("HF_TOKEN not found in environment variables. Please add it to your Space's secrets.")
login(token=hf_token)

# Model identifiers: the base checkpoint, the LoRA adapter repository, and the
# trigger word that run_lora appends to every prompt.
base_model = "black-forest-labs/FLUX.1-dev"
lora_repo = "sagar007/sagar_flux"
trigger_word = "sagar"

# Upper bound used by run_lora when it draws a random seed.
MAX_SEED = 2**32 - 1

# Prefer CUDA when torch reports it as available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the base pipeline in float16, move it to the selected device, then
# attach the LoRA weights on top of it.
# NOTE(review): float16 is also used on the CPU fallback path; confirm that
# this dtype is what is wanted there.
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16).to(device)
pipe.load_lora_weights(lora_repo)

def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
    """Generate one image for *prompt* with the LoRA-augmented pipeline.

    Args:
        prompt: User prompt; the module-level ``trigger_word`` is appended.
        cfg_scale: Value passed to the pipeline as ``guidance_scale``.
        steps: Number of inference steps (coerced to ``int``).
        randomize_seed: When true, ignore ``seed`` and draw a fresh one.
        seed: Seed for the torch generator. ``None`` (an empty UI field) is
            treated like ``randomize_seed`` being set.
        width: Output width in pixels (coerced to ``int``).
        height: Output height in pixels (coerced to ``int``).
        lora_scale: LoRA strength, forwarded as the ``scale`` attention kwarg.
        progress: Gradio progress tracker used to report start and completion.

    Returns:
        A ``(image, seed)`` tuple: the first generated image and the seed that
        was actually used, so the caller can display or reuse it.
    """
    # An empty seed field arrives as None; fall back to a random seed instead
    # of crashing inside manual_seed.
    if randomize_seed or seed is None:
        seed = random.randint(0, MAX_SEED)
    # UI numbers may arrive as floats; manual_seed requires an int.
    seed = int(seed)

    generator = torch.Generator(device=device).manual_seed(seed)
    progress(0, f"Starting image generation (using {device})...")
    image = pipe(
        prompt=f"{prompt} {trigger_word}",
        num_inference_steps=int(steps),
        guidance_scale=cfg_scale,
        width=int(width),
        height=int(height),
        generator=generator,
        # The Flux pipeline takes the LoRA scale through joint_attention_kwargs;
        # it has no cross_attention_kwargs parameter.
        joint_attention_kwargs={"scale": lora_scale},
    ).images[0]
    # Progress is a fraction in [0, 1]; 1.0 marks completion.
    progress(1.0, "Completed!")
    return image, seed

# Gradio interface setup: builds the input widgets, wires the "Generate"
# button to run_lora, and starts the web app.
with gr.Blocks() as app:
    gr.Markdown("# Text-to-Image Generation with LoRA")

    # Top row: prompt entry and trigger button on the left, result on the right.
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button("Generate")
        with gr.Column():
            result = gr.Image(label="Result")

    # Sampling controls.
    with gr.Row():
        cfg_scale = gr.Slider(minimum=1, maximum=20, value=7, step=0.1, label="CFG Scale")
        steps = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Steps")

    # Output resolution, adjustable in steps of 64 pixels.
    with gr.Row():
        width = gr.Slider(minimum=128, maximum=1024, value=512, step=64, label="Width")
        height = gr.Slider(minimum=128, maximum=1024, value=512, step=64, label="Height")

    with gr.Row():
        # Start from 0 rather than an empty field so that unchecking
        # "Randomize seed" without typing a seed never sends None to run_lora.
        seed = gr.Number(label="Seed", value=0, precision=0)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

    lora_scale = gr.Slider(minimum=0, maximum=1, value=0.75, step=0.01, label="LoRA Scale")

    # The seed field is both an input and an output: run_lora returns the seed
    # it actually used, which is written back into the field.
    run_button.click(
        run_lora,
        inputs=[prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale],
        outputs=[result, seed],
    )

# Launch the app
app.launch()