jerukperas commited on
Commit
9b4fa71
·
verified ·
1 Parent(s): cd5f59c

update gradio & add sd_embed

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +23 -5
  3. requirements.txt +5 -2
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🦢
4
  colorFrom: gray
5
  colorTo: gray
6
  sdk: gradio
7
- sdk_version: 4.41.0
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: gray
5
  colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -6,14 +6,26 @@ import gradio as gr
6
  import numpy as np
7
  import random
8
  import torch
9
- from diffusers import FluxPipeline
 
 
10
 
11
  dtype = torch.bfloat16
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
- pipe = FluxPipeline.from_pretrained(
14
- "black-forest-labs/FLUX.1-dev",
15
- torch_dtype=dtype,
 
 
 
 
 
16
  )
 
 
 
 
 
17
  pipe.to(device)
18
 
19
  MAX_SEED = np.iinfo(np.int32).max
@@ -35,8 +47,14 @@ def infer(
35
  seed = random.randint(0, MAX_SEED)
36
 
37
  generator = torch.Generator().manual_seed(seed)
 
 
 
 
 
38
  image = pipe(
39
- prompt=prompt,
 
40
  width=width,
41
  height=height,
42
  num_inference_steps=num_inference_steps,
 
6
  import numpy as np
7
  import random
8
  import torch
9
+ from diffusers import FluxPipeline, DiffusionPipeline, FluxTransformer2DModel # noqa: F401
10
+ from torchao.quantization import quantize_, int8_weight_only
11
+ from sd_embed.embedding_funcs import get_weighted_text_embeddings_flux1
12
 
13
  dtype = torch.bfloat16
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
+
16
+ transformer = FluxTransformer2DModel.from_pretrained(
17
+ "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
18
+ )
19
+ quantize_(transformer, int8_weight_only())
20
+
21
+ pipe = DiffusionPipeline.from_pretrained(
22
+ "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
23
  )
24
+
25
+ # pipe = FluxPipeline.from_pretrained(
26
+ # "black-forest-labs/FLUX.1-dev",
27
+ # torch_dtype=dtype,
28
+ # )
29
  pipe.to(device)
30
 
31
  MAX_SEED = np.iinfo(np.int32).max
 
47
  seed = random.randint(0, MAX_SEED)
48
 
49
  generator = torch.Generator().manual_seed(seed)
50
+
51
+ prompt_embeds, pooled_prompt_embeds = get_weighted_text_embeddings_flux1(
52
+ pipe=pipe, prompt=prompt
53
+ )
54
+
55
  image = pipe(
56
+ prompt_embeds=prompt_embeds,
57
+ pooled_prompt_embeds=pooled_prompt_embeds,
58
  width=width,
59
  height=height,
60
  num_inference_steps=num_inference_steps,
requirements.txt CHANGED
@@ -1,7 +1,10 @@
 
1
  accelerate
2
  git+https://github.com/huggingface/diffusers.git
3
  torch
4
- transformers==4.42.4
5
  xformers
6
  sentencepiece
7
- bitsandbytes
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu121
2
  accelerate
3
  git+https://github.com/huggingface/diffusers.git
4
  torch
5
+ transformers
6
  xformers
7
  sentencepiece
8
+ bitsandbytes
9
+ git+https://github.com/xhinker/sd_embed.git@main
10
+ torchao