AlekseyCalvin committed (verified)
Commit b203bc9 · 1 Parent(s): 08faa09

Create app.py

Files changed (1)
  1. app.py +219 -0
app.py ADDED
@@ -0,0 +1,219 @@
+ import gradio as gr
+ import json
+ import logging
+ import argparse
+ import torch
+ import transformers
+ import os
+ from os import path
+ from PIL import Image
+ import spaces
+ import copy
+ import random
+ import time
+ from huggingface_hub import hf_hub_download
+ from diffusers import FluxTransformer2DModel, FluxPipeline
+ import safetensors.torch
+ from safetensors.torch import load_file
+ from transformers import CLIPModel, CLIPProcessor, CLIPConfig
+ import gc
+
+ cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
+ os.environ["TRANSFORMERS_CACHE"] = cache_path
+ os.environ["HF_HUB_CACHE"] = cache_path
+ os.environ["HF_HOME"] = cache_path
+
+ clipmodel = 'long' # 'norm', 'long' (my fine-tunes) - 'oai', 'orgL' (OpenAI / BeichenZhang original)
+ selectedprompt = 'long' # 'tiny' (51 tokens), 'short' (75), 'med' (116), 'long' (203)
+
+ if clipmodel == "long":
+     model_id = "zer0int/LongCLIP-GmP-ViT-L-14"
+     config = CLIPConfig.from_pretrained(model_id)
+     maxtokens = 248
+
+ torch.backends.cuda.matmul.allow_tf32 = True
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ clip_model = CLIPModel.from_pretrained(model_id, torch_dtype=torch.bfloat16, config=config).to(device)
+ clip_processor = CLIPProcessor.from_pretrained(model_id, padding="max_length", max_length=maxtokens, return_tensors="pt", truncation=True)
+ config.text_config.max_position_embeddings = 248
+
+
+ pipe = FluxPipeline.from_pretrained("AlekseyCalvin/HistoricColorSoonr_v2_FluxSchnell_Diffusers", torch_dtype=torch.bfloat16)
+ pipe.to(device="cuda", dtype=torch.bfloat16)
+
+ pipe.tokenizer = clip_processor.tokenizer
+ pipe.text_encoder = clip_model.text_model
+ pipe.tokenizer_max_length = maxtokens
+ pipe.text_encoder.dtype = torch.bfloat16
+
+
+ # Load LoRAs from JSON file
+ with open('loras.json', 'r') as f:
+     loras = json.load(f)
+
+ MAX_SEED = 2**32-1
+
+ class calculateDuration:
+     def __init__(self, activity_name=""):
+         self.activity_name = activity_name
+
+     def __enter__(self):
+         self.start_time = time.time()
+         return self
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         self.end_time = time.time()
+         self.elapsed_time = self.end_time - self.start_time
+         if self.activity_name:
+             print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
+         else:
+             print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
+
+
+ def update_selection(evt: gr.SelectData, width, height):
+     selected_lora = loras[evt.index]
+     new_placeholder = f"Type a prompt for {selected_lora['title']}"
+     lora_repo = selected_lora["repo"]
+     updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"
+     if "aspect" in selected_lora:
+         if selected_lora["aspect"] == "portrait":
+             width = 768
+             height = 1024
+         elif selected_lora["aspect"] == "landscape":
+             width = 1024
+             height = 768
+     return (
+         gr.update(placeholder=new_placeholder),
+         updated_text,
+         evt.index,
+         width,
+         height,
+     )
+
+ @spaces.GPU(duration=70)
+ def generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress):
+     pipe.to("cuda")
+     generator = torch.Generator(device="cuda").manual_seed(seed)
+
+     with calculateDuration("Generating image"):
+         # Generate image (the prompt passed in already carries the trigger word)
+         image = pipe(
+             prompt=prompt,
+             num_inference_steps=steps,
+             guidance_scale=cfg_scale,
+             width=width,
+             height=height,
+             generator=generator,
+             joint_attention_kwargs={"scale": lora_scale},
+         ).images[0]
+     return image
+
+ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
+     if selected_index is None:
+         raise gr.Error("You must select a LoRA before proceeding.")
+
+     selected_lora = loras[selected_index]
+     lora_path = selected_lora["repo"]
+     trigger_word = selected_lora["trigger_word"]
+     if trigger_word:
+         if "trigger_position" in selected_lora:
+             if selected_lora["trigger_position"] == "prepend":
+                 prompt_mash = f"{trigger_word} {prompt}"
+             else:
+                 prompt_mash = f"{prompt} {trigger_word}"
+         else:
+             prompt_mash = f"{trigger_word} {prompt}"
+     else:
+         prompt_mash = prompt
+
+     # Load LoRA weights
+     with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
+         if "weights" in selected_lora:
+             pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
+         else:
+             pipe.load_lora_weights(lora_path)
+
+     # Set random seed for reproducibility
+     with calculateDuration("Randomizing seed"):
+         if randomize_seed:
+             seed = random.randint(0, MAX_SEED)
+
+     image = generate_image(prompt_mash, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress)
+     pipe.to("cpu")
+     pipe.unload_lora_weights()
+     return image, seed
+
+ run_lora.zerogpu = True
+
+ css = '''
+ #gen_btn{height: 100%}
+ #title{text-align: center}
+ #title h1{font-size: 3em; display:inline-flex; align-items:center}
+ #title img{width: 100px; margin-right: 0.5em}
+ #gallery .grid-wrap{height: 10vh}
+ '''
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
+     title = gr.HTML(
+         """<h1><img src="https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/resolve/main/flux_lora.png" alt="LoRA"> SOONfactory </h1>""",
+         elem_id="title",
+     )
+     # Info blob stating what the app is running
+     info_blob = gr.HTML(
+         """<div id="info_blob"> Activist & Futurealist LoRa-stocked Img Manufactory (currently on our Historic Color Soon®v.2 Flux Schnell (2-8 steps) model checkpoint (at AlekseyCalvin/HistoricColorSoonrFluxV2) )</div>"""
+     )
+
+     # Info blob listing the trigger words/phrases for the available LoRAs
+     info_blob = gr.HTML(
+         """<div id="info_blob">Prephrase prompts w/: 1-3. HST style |4. RCA style Communist poster |5. TOK hybrid |6. 2004 photo |7. HST style |8. LEN Vladimir Lenin |9. TOK portra |10. HST portrait |11. flmft |12. HST in Peterhof |13. HST Soviet kodachrome |14. SOTS art |15. HST Austin Osman Spare style |16. yearbook photo |17. pficonics |18. wh3r3sw4ld0 |19. retrofuturism |20. crisp |21-29. HST style photo |30. photo shot on a phone |31. unexpected photo of |32. propaganda poster of |33. Marina TSVETAEVA |34. Alexander BLOK |35. ROSA Luxemburg |36. Leon TROTSKY |37. vintage cover </div>"""
+     )
+     selected_index = gr.State(None)
+     with gr.Row():
+         with gr.Column(scale=3):
+             prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Select LoRa/Style & type prompt!")
+         with gr.Column(scale=1, elem_id="gen_column"):
+             generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
+     with gr.Row():
+         with gr.Column(scale=3):
+             selected_info = gr.Markdown("")
+             gallery = gr.Gallery(
+                 [(item["image"], item["title"]) for item in loras],
+                 label="LoRA Inventory",
+                 allow_preview=False,
+                 columns=3,
+                 elem_id="gallery"
+             )
+
+         with gr.Column(scale=4):
+             result = gr.Image(label="Generated Image")
+
+     with gr.Row():
+         with gr.Accordion("Advanced Settings", open=True):
+             with gr.Column():
+                 with gr.Row():
+                     cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=1.0)
+                     steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=6)
+
+                 with gr.Row():
+                     width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=768)
+                     height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
+
+                 with gr.Row():
+                     randomize_seed = gr.Checkbox(True, label="Randomize seed")
+                     seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
+                     lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=2.0, step=0.01, value=0.6)
+
+     gallery.select(
+         update_selection,
+         inputs=[width, height],
+         outputs=[prompt, selected_info, selected_index, width, height]
+     )
+
+     gr.on(
+         triggers=[generate_button.click, prompt.submit],
+         fn=run_lora,
+         inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
+         outputs=[result, seed]
+     )
+
+ app.queue(default_concurrency_limit=2).launch(show_error=True)
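
The script reads its gallery entries from a loras.json file that is not part of this commit. Below is a minimal sketch of the structure the code above expects, inferred from the keys it accesses in update_selection and run_lora; the title, repo name, URL, and values are hypothetical placeholders, not the actual contents of the Space's loras.json.

# Hypothetical sketch of a loras.json entry; only the keys read by app.py are shown.
import json

example_loras = [
    {
        "title": "Example LoRA",                     # gallery caption and prompt placeholder text
        "image": "https://example.com/preview.png",  # gallery thumbnail (placeholder URL)
        "repo": "username/example-flux-lora",        # repo passed to pipe.load_lora_weights()
        "trigger_word": "HST style",                 # mixed into the prompt before generation
        "weights": "lora.safetensors",               # optional: explicit weight_name within the repo
        "trigger_position": "prepend",               # optional: "prepend" puts the trigger first; anything else appends it
        "aspect": "portrait"                         # optional: "portrait" or "landscape" width/height preset
    }
]

with open("loras.json", "w") as f:
    json.dump(example_loras, f, indent=2)

Entries that omit the optional keys fall back to the defaults in the code above: the trigger word is prepended to the prompt and the user-selected width and height are kept.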