chenyangqi commited on
Commit
43ca864
1 Parent(s): 590b13e

add cache_ckpt option

Browse files
Files changed (3) hide show
  1. .gitignore +2 -1
  2. app_fatezero.py +4 -1
  3. inference_fatezero.py +31 -24
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  trash/*
2
- tmp
 
 
1
  trash/*
2
+ tmp
3
+ gradio_cached_examples
app_fatezero.py CHANGED
@@ -33,6 +33,7 @@ pipe = merge_config_then_run()
33
 
34
  with gr.Blocks(css='style.css') as demo:
35
  # gr.Markdown(TITLE)
 
36
  gr.HTML(
37
  """
38
  <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
@@ -291,7 +292,9 @@ with gr.Blocks(css='style.css') as demo:
291
  ],
292
  outputs=result,
293
  fn=pipe.run,
294
- cache_examples=os.getenv('SYSTEM') == 'spaces')
 
 
295
 
296
  # model_id.change(fn=app.load_model_info,
297
  # inputs=model_id,
 
33
 
34
  with gr.Blocks(css='style.css') as demo:
35
  # gr.Markdown(TITLE)
36
+
37
  gr.HTML(
38
  """
39
  <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
 
292
  ],
293
  outputs=result,
294
  fn=pipe.run,
295
+ cache_examples=True,
296
+ # cache_examples=os.getenv('SYSTEM') == 'spaces'
297
+ )
298
 
299
  # model_id.change(fn=app.load_model_info,
300
  # inputs=model_id,
inference_fatezero.py CHANGED
@@ -8,32 +8,39 @@ class merge_config_then_run():
8
  def __init__(self) -> None:
9
  # Load the tokenizer
10
  pretrained_model_path = 'FateZero/ckpt/stable-diffusion-v1-4'
11
- self.tokenizer = AutoTokenizer.from_pretrained(
12
- pretrained_model_path,
13
- # 'FateZero/ckpt/stable-diffusion-v1-4',
14
- subfolder="tokenizer",
15
- use_fast=False,
16
- )
 
 
 
 
 
 
 
17
 
18
- # Load models and create wrapper for stable diffusion
19
- self.text_encoder = CLIPTextModel.from_pretrained(
20
- pretrained_model_path,
21
- subfolder="text_encoder",
22
- )
23
 
24
- self.vae = AutoencoderKL.from_pretrained(
25
- pretrained_model_path,
26
- subfolder="vae",
27
- )
28
- model_config = {
29
- "lora": 160,
30
- # temporal_downsample_time: 4
31
- "SparseCausalAttention_index": ['mid'],
32
- "least_sc_channel": 640
33
- }
34
- self.unet = UNetPseudo3DConditionModel.from_2d_model(
35
- os.path.join(pretrained_model_path, "unet"), model_config=model_config
36
- )
37
 
38
  def run(
39
  self,
 
8
  def __init__(self) -> None:
9
  # Load the tokenizer
10
  pretrained_model_path = 'FateZero/ckpt/stable-diffusion-v1-4'
11
+ self.tokenizer = None
12
+ self.text_encoder = None
13
+ self.vae = None
14
+ self.unet = None
15
+
16
+ cache_ckpt = True
17
+ if cache_ckpt:
18
+ self.tokenizer = AutoTokenizer.from_pretrained(
19
+ pretrained_model_path,
20
+ # 'FateZero/ckpt/stable-diffusion-v1-4',
21
+ subfolder="tokenizer",
22
+ use_fast=False,
23
+ )
24
 
25
+ # Load models and create wrapper for stable diffusion
26
+ self.text_encoder = CLIPTextModel.from_pretrained(
27
+ pretrained_model_path,
28
+ subfolder="text_encoder",
29
+ )
30
 
31
+ self.vae = AutoencoderKL.from_pretrained(
32
+ pretrained_model_path,
33
+ subfolder="vae",
34
+ )
35
+ model_config = {
36
+ "lora": 160,
37
+ # temporal_downsample_time: 4
38
+ "SparseCausalAttention_index": ['mid'],
39
+ "least_sc_channel": 640
40
+ }
41
+ self.unet = UNetPseudo3DConditionModel.from_2d_model(
42
+ os.path.join(pretrained_model_path, "unet"), model_config=model_config
43
+ )
44
 
45
  def run(
46
  self,