sayakpaul (HF staff) committed
Commit 601c9fd
1 Parent(s): bf03ba4

Update app.py

Files changed (1):
    app.py  +46 -18
app.py CHANGED
@@ -1,37 +1,63 @@
-from huggingface_hub import model_info
+from huggingface_hub import model_info, hf_hub_download
 import gradio as gr
+import json


 def bytes_to_giga_bytes(bytes):
     return f"{(bytes / 1024 / 1024 / 1024):.3f}"


+def load_model_index(pipeline_id, token=None, revision=None):
+    index_path = hf_hub_download(repo_id=pipeline_id, filename="model_index.json", revision=revision, token=token)
+    with open(index_path, "r") as f:
+        index_dict = json.load(f)
+    return index_dict
+
+
 def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=None, extension=".safetensors"):
     if token == "":
         token = None

-    if variant == "":
-        variant = None
-
     if revision == "":
         revision = None

     if variant == "fp32":
         variant = None

-    print(pipeline_id, variant, revision, extension)
-    component_wise_memory = {}
+    print(f"pipeline_id: {pipeline_id}, variant: {variant}, revision: {revision}, extension: {extension}")

     files_in_repo = model_info(pipeline_id, revision=revision, token=token, files_metadata=True).siblings
+    index_dict = load_model_index(pipeline_id, token=token, revision=revision)
+
+    is_text_encoder_shared = any(".index.json" in file_obj.rfilename for file_obj in files_in_repo)
+    component_wise_memory = {}
+
+    # Handle text encoder separately when it's sharded.
+    if is_text_encoder_shared:
+        for current_file in files_in_repo:
+            if "text_encoder" in current_file.rfilename:
+                if not current_file.rfilename.endswith(".json") and current_file.rfilename.endswith(extension):
+                    if variant is not None and variant in current_file.rfilename:
+                        selected_file = current_file
+                    else:
+                        selected_file = current_file
+                    if "text_encoder" not in component_wise_memory:
+                        component_wise_memory["text_encoder"] = selected_file.size
+                    else:
+                        component_wise_memory["text_encoder"] += selected_file.size
+
+    print(component_wise_memory)
+
+    # Handle pipeline components.
+    component_filter = ["scheduler", "feature_extractor", "safety_checker", "tokenizer"]
+    if is_text_encoder_shared:
+        component_filter.append("text_encoder")

     for current_file in files_in_repo:
-        if all(
-            substring not in current_file.rfilename
-            for substring in ["scheduler", "feature_extractor", "safety_checker", "tokenizer"]
-        ):
+        if all(substring not in current_file.rfilename for substring in component_filter):
             is_folder = len(current_file.rfilename.split("/")) == 2
-            if is_folder:
-                filename = None
+            if is_folder and current_file.rfilename.split("/")[0] in index_dict:
+                selected_file = None
                 if not current_file.rfilename.endswith(".json") and current_file.rfilename.endswith(extension):
                     component = current_file.rfilename.split("/")[0]
                     if (
@@ -39,12 +65,13 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=No
                         and variant in current_file.rfilename
                         and "ema" not in current_file.rfilename
                     ):
-                        filename = current_file.rfilename
-                    elif "ema" not in current_file.rfilename:
-                        filename = current_file.rfilename
+                        selected_file = current_file
+                    elif variant is None and "ema" not in current_file.rfilename:
+                        selected_file = current_file

-                if filename is not None:
-                    component_wise_memory[component] = bytes_to_giga_bytes(current_file.size)
+                if selected_file is not None:
+                    print(selected_file.rfilename)
+                    component_wise_memory[component] = bytes_to_giga_bytes(selected_file.size)

     return component_wise_memory

@@ -75,7 +102,8 @@ gr.Interface(
     examples=[
         ["runwayml/stable-diffusion-v1-5", None, "fp32", None, ".safetensors"],
         ["stabilityai/stable-diffusion-xl-base-1.0", None, "fp16", None, ".safetensors"],
+        [""],
     ],
     theme=gr.themes.Soft(),
     allow_flagging=False,
-).launch()
+).launch()
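For context on the new load_model_index helper and the stricter folder check (current_file.rfilename.split("/")[0] in index_dict): model_index.json sits at the root of a Diffusers pipeline repository and maps each component folder to the library and class that loads it, so only folders listed there get sized. A rough sketch of what the returned dictionary looks like for a Stable Diffusion-style repo (abridged, illustrative values, not fetched from the Hub):

# Abridged, illustrative shape of the dict returned by load_model_index()
# for a Stable Diffusion-style pipeline; values are made up for illustration.
index_dict = {
    "_class_name": "StableDiffusionPipeline",
    "text_encoder": ["transformers", "CLIPTextModel"],
    "unet": ["diffusers", "UNet2DConditionModel"],
    "vae": ["diffusers", "AutoencoderKL"],
}

# Only top-level folders that appear as keys in the index are treated as components.
print("unet" in index_dict)               # True  -> sized by the loop above
print("some_other_folder" in index_dict)  # False -> skipped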
 
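The .index.json check that sets is_text_encoder_shared looks for a shard index file anywhere in the repo; for these pipelines that is typically the index transformers writes next to a sharded text encoder. A minimal sketch of inspecting one, assuming the usual layout with a weight_map and metadata.total_size; the repo id and filename below are hypothetical placeholders:

import json

from huggingface_hub import hf_hub_download

# Hypothetical repo id and filename; assumes the standard transformers shard
# index layout: {"metadata": {"total_size": ...}, "weight_map": {...}}.
index_path = hf_hub_download(
    repo_id="some-org/pipeline-with-sharded-text-encoder",
    filename="text_encoder/model.safetensors.index.json",
)
with open(index_path, "r") as f:
    shard_index = json.load(f)

print(shard_index["metadata"]["total_size"])         # total bytes across all shards
print(len(set(shard_index["weight_map"].values())))  # number of shard files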
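To sanity-check the updated function outside the Gradio UI, a minimal sketch, assuming the functions from the diff above are already defined in the current session (for example, pasted into a notebook), the target repo is public, and the Hub is reachable:

# Mirrors the second Gradio example above; no token is needed for a public repo.
memory = get_component_wise_memory(
    "stabilityai/stable-diffusion-xl-base-1.0",
    token=None,
    variant="fp16",
    revision=None,
    extension=".safetensors",
)
# Maps component folders (e.g. "unet", "vae", "text_encoder") to checkpoint sizes.
print(memory)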