sagar007 committed on
Commit
1042ff4
1 Parent(s): 45e862a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -9
app.py CHANGED
@@ -13,8 +13,11 @@ else:
13
 
14
  # Initialize the base model and specific LoRA
15
  base_model = "black-forest-labs/FLUX.1-dev"
16
- pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float32)
17
- pipe.to("gpu")
 
 
 
18
 
19
  lora_repo = "sagar007/sagar_flux"
20
  trigger_word = "sagar"
@@ -25,10 +28,8 @@ MAX_SEED = 2**32-1
25
  def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
26
  if randomize_seed:
27
  seed = random.randint(0, MAX_SEED)
28
- generator = torch.Generator(device="gpu").manual_seed(seed)
29
-
30
- progress(0, "Starting image generation (this may take a while on CPU)...")
31
-
32
  image = pipe(
33
  prompt=f"{prompt} {trigger_word}",
34
  num_inference_steps=steps,
@@ -38,12 +39,34 @@ def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora
38
  generator=generator,
39
  cross_attention_kwargs={"scale": lora_scale},
40
  ).images[0]
41
-
42
  progress(100, "Completed!")
43
-
44
  return image, seed
45
 
46
- # (Rest of the Gradio interface code remains the same)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  # Launch the app
49
  app.launch()
 
13
 
14
# Initialize the base model and specific LoRA
base_model = "black-forest-labs/FLUX.1-dev"

# Check if CUDA is available and move the model to GPU if possible.
# Pick the dtype to match the device: float16 halves memory on GPU, but many
# CPU kernels are unimplemented for half precision, so fall back to float32
# when no CUDA device is present (otherwise CPU inference errors out).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if device.type == "cuda" else torch.float32
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype)
pipe = pipe.to(device)
21
 
22
# LoRA adapter repository and its trigger token.
lora_repo = "sagar007/sagar_flux"
# NOTE(review): run_lora suffixes every prompt with this word — presumably the
# token the LoRA was trained on. Assumes the adapter itself is loaded elsewhere
# (e.g. pipe.load_lora_weights(lora_repo)); that call is not visible in this
# diff view — confirm against the full file.
trigger_word = "sagar"
 
28
def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
    """Generate one image with the LoRA-augmented pipeline.

    Returns (image, seed) so the UI can display the seed actually used
    when `randomize_seed` is on.
    """
    # Draw a fresh seed per request when asked, so repeated runs differ.
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Seed a generator on the same device the pipeline runs on, for
    # reproducible output given the same seed.
    generator = torch.Generator(device=device).manual_seed(seed)
    progress(0, f"Starting image generation (using {device})...")
    # NOTE(review): this diff view elides file lines 36-38 here — presumably the
    # guidance_scale/width/height keyword arguments of this call; confirm
    # against the full file.
    image = pipe(
        prompt=f"{prompt} {trigger_word}",
        num_inference_steps=steps,
        generator=generator,
        # Scale the LoRA's contribution to the attention layers.
        cross_attention_kwargs={"scale": lora_scale},
    ).images[0]
    progress(100, "Completed!")
    return image, seed
44
 
45
# Gradio interface setup.
# NOTE(review): the diff rendering dropped all indentation; the nesting below is
# reconstructed per standard Gradio Blocks layout (prompt/result columns in the
# first row, parameter rows at Blocks top level) — confirm against the full file.
with gr.Blocks() as app:
    gr.Markdown("# Text-to-Image Generation with LoRA")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button("Generate")
        with gr.Column():
            result = gr.Image(label="Result")
    # Sampler controls: classifier-free guidance strength and step count.
    with gr.Row():
        cfg_scale = gr.Slider(minimum=1, maximum=20, value=7, step=0.1, label="CFG Scale")
        steps = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Steps")
    # Output resolution, constrained to multiples of 64.
    with gr.Row():
        width = gr.Slider(minimum=128, maximum=1024, value=512, step=64, label="Width")
        height = gr.Slider(minimum=128, maximum=1024, value=512, step=64, label="Height")
    # Seed controls plus the LoRA blending weight passed through to the pipeline.
    with gr.Row():
        seed = gr.Number(label="Seed", precision=0)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        lora_scale = gr.Slider(minimum=0, maximum=1, value=0.75, step=0.01, label="LoRA Scale")

    # Wire the button to run_lora; `seed` is also an output so the UI shows
    # the seed actually used when randomization is enabled.
    run_button.click(
        run_lora,
        inputs=[prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale],
        outputs=[result, seed]
    )

# Launch the app
app.launch()