sayakpaul (HF staff) committed
Commit 78f1f97
1 Parent(s): 5d813dc

Update app.py

Files changed (1):
  1. app.py (+47 -5)

app.py CHANGED
@@ -4,6 +4,7 @@ import json
 
 component_filter = ["scheduler", "safety_checker", "tokenizer"]
 
+
 def format_size(num: int) -> str:
     """Format size in bytes into a human-readable string.
     Taken from https://stackoverflow.com/a/1094933
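component_filter lists the pipeline entries the app deliberately excludes from the size report. The line that builds the full filter used inside get_component_wise_memory sits outside this diff's context; a plausible reconstruction, based on the .extend call visible in a later hunk:

# Hypothetical reconstruction of the elided filter setup: start from the
# module-level component_filter and add the two metadata keys that
# model_index.json carries alongside the component entries.
index_filter = component_filter.copy()
index_filter.extend(["_class_name", "_diffusers_version"])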
@@ -15,6 +16,7 @@ def format_size(num: int) -> str:
         num_f /= 1000.0
     return f"{num_f:.1f}Y"
 
+
 def format_output(pipeline_id, memory_mapping):
     markdown_str = f"## {pipeline_id}\n"
     if memory_mapping:
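Only the tail of format_size shows up in these hunks. For reference, a hedged reconstruction of the whole helper, adapted from the Stack Overflow answer it cites to the decimal (1000-based) scale implied by the visible num_f /= 1000.0 and the "Y" fallback; the exact unit labels in app.py may differ:

def format_size(num: int) -> str:
    """Format size in bytes into a human-readable string.
    Taken from https://stackoverflow.com/a/1094933
    """
    num_f = float(num)
    for unit in ["", "K", "M", "G", "T", "P", "E", "Z"]:
        if abs(num_f) < 1000.0:
            return f"{num_f:3.1f}{unit}"
        num_f /= 1000.0
    return f"{num_f:.1f}Y"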
@@ -22,12 +24,14 @@ def format_output(pipeline_id, memory_mapping):
             markdown_str += f"* {component}: {format_size(memory)}\n"
     return markdown_str
 
+
 def load_model_index(pipeline_id, token=None, revision=None):
     index_path = hf_hub_download(repo_id=pipeline_id, filename="model_index.json", revision=revision, token=token)
     with open(index_path, "r") as f:
         index_dict = json.load(f)
     return index_dict
 
+
 def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=None, extension=".safetensors"):
     if token == "":
         token = None
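load_model_index simply downloads model_index.json from the Hub and parses it; every key in that file that survives the filter is treated as a weight-bearing component to size. A minimal usage sketch (the module name app, the repo id, and the dict contents in the comment are illustrative, not taken from this commit):

# Assumes app.py is importable as a module named app.
from app import component_filter, load_model_index

index = load_model_index("runwayml/stable-diffusion-v1-5")
# model_index.json maps component names to [library, class] pairs, e.g.
#   "unet": ["diffusers", "UNet2DConditionModel"],
#   "vae": ["diffusers", "AutoencoderKL"],
# plus the "_class_name" and "_diffusers_version" metadata keys.
components = [k for k in index if k not in component_filter and not k.startswith("_")]
print(components)  # e.g. ["text_encoder", "unet", "vae", ...]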
@@ -48,17 +52,22 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=None, extension=".safetensors"):
     index_filter.extend(["_class_name", "_diffusers_version"])
     for current_component in index_dict:
         if current_component not in index_filter:
-            current_component_fileobjs = list(filter(lambda x: current_component in x.rfilename, files_in_repo))
+            current_component_fileobjs = list(filter(lambda x: current_component in x.rfilename, files_in_repo))
             if current_component_fileobjs:
                 current_component_filenames = [fileobj.rfilename for fileobj in current_component_fileobjs]
-                condition = lambda filename: extension in filename and variant in filename if variant is not None else lambda filename: extension in filename
+                condition = (
+                    lambda filename: extension in filename and variant in filename
+                    if variant is not None
+                    else lambda filename: extension in filename
+                )
                 variant_present_with_extension = any(condition(filename) for filename in current_component_filenames)
                 if not variant_present_with_extension:
-                    raise ValueError(f"Requested extension ({extension}) and variant ({variant}) not present for {current_component}. Available files for this component:\n{current_component_filenames}.")
+                    raise ValueError(
+                        f"Requested extension ({extension}) and variant ({variant}) not present for {current_component}. Available files for this component:\n{current_component_filenames}."
+                    )
             else:
                 raise ValueError(f"Problem with {current_component}.")
 
-
     # Handle text encoder separately when it's sharded.
     is_text_encoder_shared = any(".index.json" in file_obj.rfilename for file_obj in files_in_repo)
     component_wise_memory = {}
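One behavioral note on the condition above: lambda binds looser than the conditional expression, so the whole if/else is parsed as the body of the first lambda. When variant is None, condition(filename) therefore returns the inner lambda (which is always truthy) rather than checking the extension. A minimal sketch of an explicitly parenthesized variant that actually selects between two predicates (same names as in the diff, not the committed code):

# Explicit parentheses make each branch its own predicate; without them the
# entire conditional expression becomes the body of the first lambda.
condition = (
    (lambda filename: extension in filename and variant in filename)
    if variant is not None
    else (lambda filename: extension in filename)
)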
@@ -99,4 +108,37 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=None, extension=".safetensors"):
         print(selected_file.rfilename)
         component_wise_memory[component] = selected_file.size
 
-    return format_output(pipeline_id, component_wise_memory)
+    return format_output(pipeline_id, component_wise_memory)
+
+
+gr.Interface(
+    title="Compute component-wise memory of a 🧨 Diffusers pipeline.",
+    description="Sizes will be reported in GB. Pipelines containing text encoders with sharded checkpoints are also supported (PixArt-Alpha, for example) 🤗",
+    fn=get_component_wise_memory,
+    inputs=[
+        gr.components.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5"),
+        gr.components.Textbox(lines=1, label="hf_token", info="Pass this in case of private repositories."),
+        gr.components.Dropdown(
+            [
+                "fp32",
+                "fp16",
+            ],
+            label="variant",
+            info="Precision to use for calculation.",
+        ),
+        gr.components.Textbox(lines=1, label="revision", info="Repository revision to use."),
+        gr.components.Dropdown(
+            [".bin", ".safetensors"],
+            label="extension",
+            info="Extension to use.",
+        ),
+    ],
+    outputs=[gr.Markdown(label="Output")],
+    examples=[
+        ["runwayml/stable-diffusion-v1-5", None, "fp32", None, ".safetensors"],
+        ["stabilityai/stable-diffusion-xl-base-1.0", None, "fp16", None, ".safetensors"],
+        ["PixArt-alpha/PixArt-XL-2-1024-MS", None, "fp32", None, ".safetensors"],
+    ],
+    theme=gr.themes.Soft(),
+    allow_flagging=False,
+).launch(show_error=True)
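The sharded-text-encoder path flagged in the code ("Handle text encoder separately when it's sharded") is detected by looking for a *.index.json file among the repo files; the per-shard sizes then have to be summed instead of picking a single file. A rough sketch of that idea using huggingface_hub directly, not the app's exact logic; the repo id and the text_encoder/ path prefix are assumptions:

from huggingface_hub import model_info

# files_metadata=True is required so each sibling carries its size in bytes.
# Repo id and path prefix below are illustrative.
info = model_info("PixArt-alpha/PixArt-XL-2-1024-MS", files_metadata=True)
text_encoder_bytes = sum(
    (sibling.size or 0)
    for sibling in info.siblings
    if sibling.rfilename.startswith("text_encoder/") and sibling.rfilename.endswith(".safetensors")
)
print(f"text_encoder: {text_encoder_bytes / 1e9:.1f} GB")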