Text-to-Image
daoyuan98 commited on
Commit
52e7089
1 Parent(s): ef7c6c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -18
app.py CHANGED
@@ -1,12 +1,17 @@
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
  import spaces
5
  import torch
6
- from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
 
 
7
  from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
8
 
9
- from model import Flux
10
 
11
  def calculate_shift(
12
  image_seq_len,
@@ -174,20 +179,17 @@ def flux_pipe_call_that_returns_an_iterable_of_images(
174
  @dataclass
175
  class ModelSpec:
176
  params: FluxParams
177
- ae_params: AutoEncoderParams
178
- ckpt_path: str
179
- ae_path: str
180
  repo_id: str
181
  repo_flow: str
182
  repo_ae: str
183
  repo_id_ae: str
184
 
 
185
  config = ModelSpec(
186
  repo_id="TencentARC/flux-mini",
187
  repo_flow="flux-mini.safetensors",
188
  repo_id_ae="black-forest-labs/FLUX.1-dev",
189
  repo_ae="ae.safetensors",
190
- ckpt_path=os.getenv("FLUX_MINI", None),
191
  params=FluxParams(
192
  in_channels=64,
193
  vec_in_dim=768,
@@ -202,35 +204,33 @@ config = ModelSpec(
202
  qkv_bias=True,
203
  guidance_embed=True,
204
  )
 
205
 
206
 
207
- def load_flow_model2(device: str = "cuda", hf_download: bool = True):
208
- if (
209
- and config.repo_id is not None
210
  and config.repo_flow is not None
211
  and hf_download
212
  ):
213
- ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))
214
 
215
- model = Flux(params)
216
  if ckpt_path is not None:
217
  sd = load_sft(ckpt_path, device=str(device))
218
  missing, unexpected = model.load_state_dict(sd, strict=True)
219
  return model
220
 
221
 
222
-
223
-
224
  dtype = torch.bfloat16
225
  device = "cuda" if torch.cuda.is_available() else "cpu"
226
 
227
- scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="scheduler").to(device)
228
  vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
229
  text_encoder = CLIPTextModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder").to(device)
230
- tokenizer = CLIPTokenizer.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer").to(device)
231
  text_encoder_2 = T5EncoderModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2").to(device)
232
- tokenizer_2 = T5TokenizerFast.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer_2").to(device)
233
- transformer = load_flow_model2(device)
234
 
235
  pipe = FluxPipeline(
236
  scheduler,
@@ -238,7 +238,7 @@ pipe = FluxPipeline(
238
  text_encoder,
239
  tokenizer,
240
  text_encoder_2,
241
- tokenizer_2
242
  transformer
243
  )
244
  torch.cuda.empty_cache()
 
1
+ from dataclasses import dataclass
2
+ from typing import Union, Optional, List, Any, Dict
3
+
4
  import gradio as gr
5
  import numpy as np
6
  import random
7
  import spaces
8
  import torch
9
+ from huggingface_hub import hf_hub_download
10
+
11
+ from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL, FluxPipeline
12
  from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
13
 
14
+ from model import Flux, FluxParams
15
 
16
  def calculate_shift(
17
  image_seq_len,
 
179
  @dataclass
180
  class ModelSpec:
181
  params: FluxParams
 
 
 
182
  repo_id: str
183
  repo_flow: str
184
  repo_ae: str
185
  repo_id_ae: str
186
 
187
+
188
  config = ModelSpec(
189
  repo_id="TencentARC/flux-mini",
190
  repo_flow="flux-mini.safetensors",
191
  repo_id_ae="black-forest-labs/FLUX.1-dev",
192
  repo_ae="ae.safetensors",
 
193
  params=FluxParams(
194
  in_channels=64,
195
  vec_in_dim=768,
 
204
  qkv_bias=True,
205
  guidance_embed=True,
206
  )
207
+ )
208
 
209
 
210
+ def load_flow_model2(config, device: str = "cuda", hf_download: bool = True):
211
+ if (config.repo_id is not None
 
212
  and config.repo_flow is not None
213
  and hf_download
214
  ):
215
+ ckpt_path = hf_hub_download(config.repo_id, config.repo_flow.replace("sft", "safetensors"))
216
 
217
+ model = Flux(config.params)
218
  if ckpt_path is not None:
219
  sd = load_sft(ckpt_path, device=str(device))
220
  missing, unexpected = model.load_state_dict(sd, strict=True)
221
  return model
222
 
223
 
 
 
224
  dtype = torch.bfloat16
225
  device = "cuda" if torch.cuda.is_available() else "cpu"
226
 
227
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="scheduler")
228
  vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
229
  text_encoder = CLIPTextModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder").to(device)
230
+ tokenizer = CLIPTokenizer.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer")
231
  text_encoder_2 = T5EncoderModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2").to(device)
232
+ tokenizer_2 = T5TokenizerFast.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer_2")
233
+ transformer = load_flow_model2(config, device)
234
 
235
  pipe = FluxPipeline(
236
  scheduler,
 
238
  text_encoder,
239
  tokenizer,
240
  text_encoder_2,
241
+ tokenizer_2,
242
  transformer
243
  )
244
  torch.cuda.empty_cache()