Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -13,8 +13,11 @@ else:
|
|
13 |
|
14 |
# Initialize the base model and specific LoRA
|
15 |
base_model = "black-forest-labs/FLUX.1-dev"
|
16 |
-
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.
|
17 |
-
|
|
|
|
|
|
|
18 |
|
19 |
lora_repo = "sagar007/sagar_flux"
|
20 |
trigger_word = "sagar"
|
@@ -25,10 +28,8 @@ MAX_SEED = 2**32-1
|
|
25 |
def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
|
26 |
if randomize_seed:
|
27 |
seed = random.randint(0, MAX_SEED)
|
28 |
-
generator = torch.Generator(device=
|
29 |
-
|
30 |
-
progress(0, "Starting image generation (this may take a while on CPU)...")
|
31 |
-
|
32 |
image = pipe(
|
33 |
prompt=f"{prompt} {trigger_word}",
|
34 |
num_inference_steps=steps,
|
@@ -38,12 +39,34 @@ def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora
|
|
38 |
generator=generator,
|
39 |
cross_attention_kwargs={"scale": lora_scale},
|
40 |
).images[0]
|
41 |
-
|
42 |
progress(100, "Completed!")
|
43 |
-
|
44 |
return image, seed
|
45 |
|
46 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
# Launch the app
|
49 |
app.launch()
|
|
|
13 |
|
14 |
# Initialize the base model and specific LoRA
|
15 |
base_model = "black-forest-labs/FLUX.1-dev"
|
16 |
+
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16)
|
17 |
+
|
18 |
+
# Check if CUDA is available and move the model to GPU if possible
|
19 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
20 |
+
pipe = pipe.to(device)
|
21 |
|
22 |
lora_repo = "sagar007/sagar_flux"
|
23 |
trigger_word = "sagar"
|
|
|
28 |
def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
|
29 |
if randomize_seed:
|
30 |
seed = random.randint(0, MAX_SEED)
|
31 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
32 |
+
progress(0, f"Starting image generation (using {device})...")
|
|
|
|
|
33 |
image = pipe(
|
34 |
prompt=f"{prompt} {trigger_word}",
|
35 |
num_inference_steps=steps,
|
|
|
39 |
generator=generator,
|
40 |
cross_attention_kwargs={"scale": lora_scale},
|
41 |
).images[0]
|
|
|
42 |
progress(100, "Completed!")
|
|
|
43 |
return image, seed
|
44 |
|
45 |
+
# Gradio interface setup
|
46 |
+
with gr.Blocks() as app:
|
47 |
+
gr.Markdown("# Text-to-Image Generation with LoRA")
|
48 |
+
with gr.Row():
|
49 |
+
with gr.Column():
|
50 |
+
prompt = gr.Textbox(label="Prompt")
|
51 |
+
run_button = gr.Button("Generate")
|
52 |
+
with gr.Column():
|
53 |
+
result = gr.Image(label="Result")
|
54 |
+
with gr.Row():
|
55 |
+
cfg_scale = gr.Slider(minimum=1, maximum=20, value=7, step=0.1, label="CFG Scale")
|
56 |
+
steps = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Steps")
|
57 |
+
with gr.Row():
|
58 |
+
width = gr.Slider(minimum=128, maximum=1024, value=512, step=64, label="Width")
|
59 |
+
height = gr.Slider(minimum=128, maximum=1024, value=512, step=64, label="Height")
|
60 |
+
with gr.Row():
|
61 |
+
seed = gr.Number(label="Seed", precision=0)
|
62 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
63 |
+
lora_scale = gr.Slider(minimum=0, maximum=1, value=0.75, step=0.01, label="LoRA Scale")
|
64 |
+
|
65 |
+
run_button.click(
|
66 |
+
run_lora,
|
67 |
+
inputs=[prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale],
|
68 |
+
outputs=[result, seed]
|
69 |
+
)
|
70 |
|
71 |
# Launch the app
|
72 |
app.launch()
|