gokaygokay commited on
Commit
47b9af6
·
verified ·
1 Parent(s): 8a7a560

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -102
app.py CHANGED
@@ -1,15 +1,8 @@
1
  import spaces
2
  import gradio as gr
3
  import torch
4
- from PIL import Image
5
- from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
6
- from diffusers import DiffusionPipeline
7
  import random
8
- import numpy as np
9
- import os
10
- import subprocess
11
-
12
- subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
13
 
14
  # Initialize models
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -17,74 +10,36 @@ dtype = torch.bfloat16
17
 
18
  huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
19
 
20
- # FLUX.1-schnell model
21
- pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype, revision="refs/pr/1", token=huggingface_token).to(device)
22
-
23
- # Initialize Florence model
24
- florence_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to(device).eval()
25
- florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
26
 
27
- # Prompt Enhancer
28
- enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
29
-
30
- MAX_SEED = np.iinfo(np.int32).max
31
- MAX_IMAGE_SIZE = 2048
32
-
33
- # Florence caption function
34
- def florence_caption(image):
35
- if not isinstance(image, Image.Image):
36
- image = Image.fromarray(image)
37
-
38
- inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
39
- generated_ids = florence_model.generate(
40
- input_ids=inputs["input_ids"],
41
- pixel_values=inputs["pixel_values"],
42
- max_new_tokens=1024,
43
- early_stopping=False,
44
- do_sample=False,
45
- num_beams=3,
46
- )
47
- generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
48
- parsed_answer = florence_processor.post_process_generation(
49
- generated_text,
50
- task="<MORE_DETAILED_CAPTION>",
51
- image_size=(image.width, image.height)
52
- )
53
- return parsed_answer["<MORE_DETAILED_CAPTION>"]
54
 
55
- # Prompt Enhancer function
56
- def enhance_prompt(input_prompt):
57
- result = enhancer_long("Enhance the description: " + input_prompt)
58
- enhanced_text = result[0]['summary_text']
59
- return enhanced_text
60
 
61
  @spaces.GPU(duration=75)
62
- def process_workflow(image, text_prompt, use_enhancer, seed, randomize_seed, width, height, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
63
- if image is not None:
64
- if not isinstance(image, Image.Image):
65
- image = Image.fromarray(image)
66
- prompt = florence_caption(image)
67
- else:
68
- prompt = text_prompt
69
-
70
- if use_enhancer:
71
- prompt = enhance_prompt(prompt)
72
-
73
- if randomize_seed:
74
- seed = random.randint(0, MAX_SEED)
75
-
76
- generator = torch.Generator(device=device).manual_seed(seed)
77
 
78
  image = pipe(
79
  prompt=prompt,
80
- generator=generator,
81
- num_inference_steps=num_inference_steps,
82
  width=width,
83
  height=height,
84
- guidance_scale=0.0
 
85
  ).images[0]
 
 
 
 
 
86
 
87
- return image, prompt, seed
 
88
 
89
  custom_css = """
90
  .input-group, .output-group {
@@ -103,47 +58,39 @@ custom_css = """
103
  }
104
  """
105
 
106
- title = """<h1 align="center">FLUX.1-schnell with Florence-2 Captioner and Prompt Enhancer</h1>
107
- <p><center>
108
- <a href="https://huggingface.co/black-forest-labs/FLUX.1-schnell" target="_blank">[FLUX.1-schnell Model]</a>
109
- <a href="https://huggingface.co/microsoft/Florence-2-base" target="_blank">[Florence-2 Model]</a>
110
- <a href="https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance-Long" target="_blank">[Prompt Enhancer Long]</a>
111
- <p align="center">Create long prompts from images or enhance your short prompts with prompt enhancer</p>
112
- </center></p>
113
  """
114
 
115
- with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")) as demo:
116
  gr.HTML(title)
117
 
118
  with gr.Row():
119
- with gr.Column(scale=1):
120
- with gr.Group(elem_classes="input-group"):
121
- input_image = gr.Image(label="Input Image (Florence-2 Captioner)")
122
-
123
- with gr.Accordion("Advanced Settings", open=False):
124
- text_prompt = gr.Textbox(label="Text Prompt (optional, used if no image is uploaded)")
125
- use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
126
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
127
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
128
- width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
129
- height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
130
- num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=4)
131
-
132
- generate_btn = gr.Button("Generate Image", elem_classes="submit-btn")
133
-
134
- with gr.Column(scale=1):
135
- with gr.Group(elem_classes="output-group"):
136
- output_image = gr.Image(label="Result", elem_id="gallery", show_label=False)
137
- final_prompt = gr.Textbox(label="Final Prompt Used")
138
- used_seed = gr.Number(label="Seed Used")
139
 
140
- generate_btn.click(
141
- fn=process_workflow,
142
- inputs=[
143
- input_image, text_prompt, use_enhancer, seed, randomize_seed,
144
- width, height, num_inference_steps
145
- ],
146
- outputs=[output_image, final_prompt, used_seed]
147
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- demo.launch(debug=True)
 
1
  import spaces
2
  import gradio as gr
3
  import torch
 
 
 
4
  import random
5
+ from diffusers import DiffusionPipeline
 
 
 
 
6
 
7
  # Initialize models
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
10
 
11
  huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
12
 
13
+ # Initialize the base model and move it to GPU
14
+ base_model = "black-forest-labs/FLUX.1-dev"
15
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16, token=huggingface_token).to("cuda")
 
 
 
16
 
17
+ # Load LoRA weights
18
+ pipe.load_lora_weights("gokaygokay/Flux-Detailer-LoRA")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
+ MAX_SEED = 2**32-1
 
 
 
 
21
 
22
  @spaces.GPU(duration=75)
23
+ def generate_image(prompt, steps, seed, cfg_scale, width, height, lora_scale):
24
+ generator = torch.Generator(device="cuda").manual_seed(seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  image = pipe(
27
  prompt=prompt,
28
+ num_inference_steps=steps,
29
+ guidance_scale=cfg_scale,
30
  width=width,
31
  height=height,
32
+ generator=generator,
33
+ joint_attention_kwargs={"scale": lora_scale},
34
  ).images[0]
35
+ return image
36
+
37
+ def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale):
38
+ if randomize_seed:
39
+ seed = random.randint(0, MAX_SEED)
40
 
41
+ image = generate_image(prompt, steps, seed, cfg_scale, width, height, lora_scale)
42
+ return image, seed
43
 
44
  custom_css = """
45
  .input-group, .output-group {
 
58
  }
59
  """
60
 
61
+ title = """<h1 align="center">FLUX Creativity LoRA</h1>
 
 
 
 
 
 
62
  """
63
 
64
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray"), css=custom_css) as app:
65
  gr.HTML(title)
66
 
67
  with gr.Row():
68
+ prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Type your prompt here")
69
+
70
+ with gr.Row():
71
+ generate_button = gr.Button("Generate", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
+ with gr.Row():
74
+ result = gr.Image(label="Generated Image")
75
+
76
+ with gr.Accordion("Advanced Settings", open=False):
77
+ with gr.Row():
78
+ cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
79
+ steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
80
+
81
+ with gr.Row():
82
+ width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
83
+ height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
84
+
85
+ with gr.Row():
86
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
87
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
88
+ lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.95)
89
+
90
+ inputs = [prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale]
91
+ outputs = [result, seed]
92
+
93
+ generate_button.click(fn=run_lora, inputs=inputs, outputs=outputs)
94
+ prompt.submit(fn=run_lora, inputs=inputs, outputs=outputs)
95
 
96
+ app.launch(debug=True)