John6666 committed
Commit dc1288a
1 Parent(s): 8c47eec

Upload app.py

Files changed (1)
  1. app.py +44 -19
app.py CHANGED
@@ -4,7 +4,8 @@ import json
 import logging
 import torch
 from PIL import Image
-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
+from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
 from diffusers import FluxControlNetPipeline, FluxControlNetModel, FluxMultiControlNetModel
 from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
 import copy
@@ -21,16 +22,29 @@ from flux import (search_civitai_lora, select_civitai_lora, search_civitai_lora_
 from tagger.tagger import predict_tags_wd, compose_prompt_to_copy
 from tagger.fl2flux import predict_tags_fl2_flux
 
+# Load LoRAs from JSON file
+with open('loras.json', 'r') as f:
+    loras = json.load(f)
+
+dtype = torch.bfloat16
+#dtype = torch.float8_e4m3fn
+device = "cuda" if torch.cuda.is_available() else "cpu"
 # Initialize the base model
 base_model = models[0]
 controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union'
 #controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union-alpha'
-pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
+taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
+good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
+pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
 controlnet_union = None
 controlnet = None
 last_model = models[0]
 last_cn_on = False
 
+MAX_SEED = 2**32-1
+
+pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
+
 # https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union
 # https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union
 def change_base_model(repo_id: str, cn_on: bool):
@@ -39,8 +53,6 @@ def change_base_model(repo_id: str, cn_on: bool):
     global controlnet
     global last_model
     global last_cn_on
-    dtype = torch.bfloat16
-    #dtype = torch.float8_e4m3fn
     try:
         if (repo_id == last_model and cn_on is last_cn_on) or not is_repo_name(repo_id) or not is_repo_exists(repo_id): return gr.update(visible=True)
         if cn_on:
@@ -50,6 +62,7 @@ def change_base_model(repo_id: str, cn_on: bool):
             controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=dtype)
             controlnet = FluxMultiControlNetModel([controlnet_union])
             pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, torch_dtype=dtype)
+            #pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
             last_model = repo_id
             last_cn_on = cn_on
             #progress(1, desc=f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}")
@@ -58,7 +71,8 @@ def change_base_model(repo_id: str, cn_on: bool):
             #progress(0, desc=f"Loading model: {repo_id}")
             print(f"Loading model: {repo_id}")
             clear_cache()
-            pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype)
+            pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype, vae=taef1)
+            pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
             last_model = repo_id
             last_cn_on = cn_on
             #progress(1, desc=f"Model loaded: {repo_id}")
@@ -70,12 +84,6 @@ def change_base_model(repo_id: str, cn_on: bool):
 
 change_base_model.zerogpu = True
 
-# Load LoRAs from JSON file
-with open('loras.json', 'r') as f:
-    loras = json.load(f)
-
-MAX_SEED = 2**32-1
-
 class calculateDuration:
     def __init__(self, activity_name=""):
         self.activity_name = activity_name
@@ -115,9 +123,13 @@ def update_selection(evt: gr.SelectData, width, height):
 @spaces.GPU(duration=70)
 def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, cn_on, progress=gr.Progress(track_tqdm=True)):
     global pipe
+    global taef1
+    global good_vae
     global controlnet
     global controlnet_union
     try:
+        good_vae.to("cuda")
+        taef1.to("cuda")
         pipe.to("cuda")
         generator = torch.Generator(device="cuda").manual_seed(seed)
 
@@ -126,7 +138,7 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scal
         modes, images, scales = get_control_params()
         if not cn_on or len(modes) == 0:
            progress(0, desc="Start Inference.")
-            image = pipe(
+            for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
                prompt=prompt_mash,
                num_inference_steps=steps,
                guidance_scale=cfg_scale,
@@ -134,12 +146,15 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scal
                height=height,
                generator=generator,
                joint_attention_kwargs={"scale": lora_scale},
-            ).images[0]
+                output_type="pil",
+                good_vae=good_vae,
+            ):
+                yield img
        else:
            progress(0, desc="Start Inference with ControlNet.")
            if controlnet is not None: controlnet.to("cuda")
            if controlnet_union is not None: controlnet_union.to("cuda")
-            image = pipe(
+            for img in pipe(
                prompt=prompt_mash,
                control_image=images,
                control_mode=modes,
@@ -150,15 +165,19 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scal
                controlnet_conditioning_scale=scales,
                generator=generator,
                joint_attention_kwargs={"scale": lora_scale},
-            ).images[0]
+            ).images:
+                yield img
    except Exception as e:
        print(e)
        raise gr.Error(f"Inference Error: {e}")
-    return image
 
 def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
             lora_scale, lora_json, cn_on, progress=gr.Progress(track_tqdm=True)):
     global pipe
+    global taef1
+    global good_vae
+    global controlnet
+    global controlnet_union
     if selected_index is None and not is_valid_lora(lora_json):
        gr.Info("LoRA isn't selected.")
        # raise gr.Error("You must select a LoRA before proceeding.")
@@ -197,17 +216,23 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, wid
        seed = random.randint(0, MAX_SEED)
 
    progress(0, desc="Running Inference.")
-
-    image = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, cn_on, progress)
+    image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, cn_on, progress)
+    # Consume the generator to get the final image
+    final_image = None
+    for image in image_generator:
+        final_image = image
+        yield image, seed  # Yield intermediate images and seed
    if is_valid_lora(lora_json):
        pipe.unfuse_lora()
        pipe.unload_lora_weights()
    if selected_index is not None: pipe.unload_lora_weights()
    pipe.to("cpu")
+    good_vae.to("cpu")
+    taef1.to("cpu")
    if controlnet is not None: controlnet.to("cpu")
    if controlnet_union is not None: controlnet_union.to("cpu")
    clear_cache()
-    return image, seed
+    return final_image, seed  # Return the final image and seed
 
 def get_huggingface_safetensors(link):
    split_link = link.split("/")
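
The core of this commit is a streaming-preview pattern: a helper generator from live_preview_helpers is bound onto the pipeline instance with __get__, generate_image becomes a generator that yields intermediate images (decoded cheaply through the taef1 tiny VAE, with good_vae passed in for the final frame), and run_lora re-yields those images so the Gradio endpoint can stream previews to the UI. Below is a minimal, self-contained sketch of the two Python idioms involved; StubPipe, preview_call and run are illustrative stand-ins, not the actual live_preview_helpers code.

# Minimal sketch of the two idioms used in the diff (illustrative names, no real models).

class StubPipe:
    """Stands in for the diffusers pipeline object."""
    pass

def preview_call(self, steps):
    # A plain function that will be bound to one pipeline instance. Like
    # flux_pipe_call_that_returns_an_iterable_of_images, it is a generator:
    # it yields an intermediate result per step, with the final result last.
    for i in range(1, steps + 1):
        yield f"preview after step {i}"  # the real helper yields decoded PIL images

pipe = StubPipe()
# __get__ binds the free function as a method of this specific instance,
# which is how app.py attaches the helper to each freshly loaded pipeline.
pipe.preview_call = preview_call.__get__(pipe)

def run(steps):
    # Mirrors run_lora: consume the inner generator, re-yield every update so a
    # Gradio generator endpoint can stream it, and remember the last value.
    final = None
    for img in pipe.preview_call(steps):
        final = img
        yield img
    # cleanup (moving models back to the CPU, unloading LoRA weights) goes here
    return final  # inside a generator, this only sets StopIteration.value

for update in run(3):
    print(update)

Because run_lora now contains yield, its trailing return final_image, seed is not delivered to the frontend as a return value; the UI ends up showing the last yielded image, which is intended to be the full-quality frame decoded through good_vae rather than a taef1 preview.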