sayakpaul committed
Commit
3f6a1fe
1 Parent(s): db390f9

add support for controlnet and t2i adapter too

Files changed (1)
  1. app.py +99 -10
app.py CHANGED
@@ -11,6 +11,24 @@ COMPONENT_FILTER = [
     "_diffusers_version",
 ]
 
+ARTICLE = """
+## Notes on how to use the `controlnet_id` and `t2i_adapter_id` fields
+
+Both the `controlnet_id` and `t2i_adapter_id` fields support passing multiple checkpoint ids,
+e.g., "thibaud/controlnet-openpose-sdxl-1.0,diffusers/controlnet-canny-sdxl-1.0". For
+`t2i_adapter_id`, this could look like "TencentARC/t2iadapter_keypose_sd14v1,TencentARC/t2iadapter_depth_sd14v1".
+
+Users should take care to pass an appropriate base `pipeline_id`. For example,
+passing `pipeline_id` as "runwayml/stable-diffusion-v1-5" and `controlnet_id` as "thibaud/controlnet-openpose-sdxl-1.0"
+won't result in an error, but the two aren't meant to be compatible. You should pass
+a `controlnet_id` that is compatible with "runwayml/stable-diffusion-v1-5".
+
+For further clarification on this topic, feel free to open a [discussion](https://huggingface.co/spaces/diffusers/compute-pipeline-size/discussions).
+
+📔 Also, note that the `revision` field applies only to the `pipeline_id`. It won't have any effect on
+`controlnet_id` or `t2i_adapter_id`.
+"""
+
 
 def format_size(num: int) -> str:
     """Format size in bytes into a human-readable string.
@@ -24,11 +42,21 @@ def format_size(num: int) -> str:
     return f"{num_f:.1f}Y"
 
 
-def format_output(pipeline_id, memory_mapping):
+def format_output(pipeline_id, memory_mapping, controlnet_mapping=None, t2i_adapter_mapping=None):
     markdown_str = f"## {pipeline_id}\n"
+
     if memory_mapping:
         for component, memory in memory_mapping.items():
             markdown_str += f"* {component}: {format_size(memory)}\n"
+    if controlnet_mapping:
+        markdown_str += "\n## ControlNet(s)\n"
+        for controlnet_id, memory in controlnet_mapping.items():
+            markdown_str += f"* {controlnet_id}: {format_size(memory)}\n"
+    if t2i_adapter_mapping:
+        markdown_str += "\n## T2I-Adapter(s)\n"
+        for t2i_adapter_id, memory in t2i_adapter_mapping.items():
+            markdown_str += f"* {t2i_adapter_id}: {format_size(memory)}\n"
+
     return markdown_str
 
 
@@ -39,7 +67,35 @@ def load_model_index(pipeline_id, token=None, revision=None):
     return index_dict
 
 
-def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=None, extension=".safetensors"):
+def get_individual_model_memory(id, token, variant, extension):
+    files_in_repo = model_info(id, token=token, files_metadata=True).siblings
+    for x in files_in_repo:
+        if extension in x.rfilename:
+            if variant:
+                if variant in x.rfilename:
+                    return x.size
+            else:
+                return x.size
+
+
+def get_component_wise_memory(
+    pipeline_id,
+    controlnet_id=None,
+    t2i_adapter_id=None,
+    token=None,
+    variant=None,
+    revision=None,
+    extension=".safetensors",
+):
+    if controlnet_id == "":
+        controlnet_id = None
+
+    if t2i_adapter_id == "":
+        t2i_adapter_id = None
+
+    if controlnet_id and t2i_adapter_id:
+        raise ValueError("`controlnet_id` and `t2i_adapter_id` cannot both be provided.")
+
     if token == "":
         token = None
 
@@ -49,12 +105,31 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=None, extension=".safetensors"):
     if variant == "fp32":
         variant = None
 
+    # Handle ControlNet and T2I-Adapter.
+    controlnet_mapping = t2i_adapter_mapping = None
+    if controlnet_id is not None:
+        controlnet_id = controlnet_id.split(",")
+        controlnet_sizes = [
+            get_individual_model_memory(id_, token=token, variant=variant, extension=extension)
+            for id_ in controlnet_id
+        ]
+        controlnet_mapping = dict(zip(controlnet_id, controlnet_sizes))
+    elif t2i_adapter_id is not None:
+        t2i_adapter_id = t2i_adapter_id.split(",")
+        t2i_adapter_sizes = [
+            get_individual_model_memory(id_, token=token, variant=variant, extension=extension)
+            for id_ in t2i_adapter_id
+        ]
+        t2i_adapter_mapping = dict(zip(t2i_adapter_id, t2i_adapter_sizes))
+
     print(f"pipeline_id: {pipeline_id}, variant: {variant}, revision: {revision}, extension: {extension}")
 
+    # Load pipeline metadata.
     files_in_repo = model_info(pipeline_id, revision=revision, token=token, files_metadata=True).siblings
     index_dict = load_model_index(pipeline_id, token=token, revision=revision)
 
-    # Check if all the concerned components have the checkpoints in the requested "variant" and "extension".
+    # Check if all the concerned components have the checkpoints in
+    # the requested "variant" and "extension".
     print(f"Index dict: {index_dict}")
     for current_component in index_dict:
         if (
@@ -63,6 +138,7 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=None, extension=".safetensors"):
             and len(index_dict[current_component]) == 2
         ):
             current_component_fileobjs = list(filter(lambda x: current_component in x.rfilename, files_in_repo))
+
             if current_component_fileobjs:
                 current_component_filenames = [fileobj.rfilename for fileobj in current_component_fileobjs]
                 condition = (  # noqa: E731
@@ -119,16 +195,20 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=None, extension=".safetensors"):
         if selected_file is not None:
             component_wise_memory[component] = selected_file.size
 
-    return format_output(pipeline_id, component_wise_memory)
+    return format_output(pipeline_id, component_wise_memory, controlnet_mapping, t2i_adapter_mapping)
 
 
 with gr.Interface(
     title="Compute component-wise memory of a 🧨 Diffusers pipeline.",
     description="Pipelines containing text encoders with sharded checkpoints are also supported"
-    " (PixArt-Alpha, for example) 🤗",
+    " (PixArt-Alpha, for example) 🤗 See the instructions below the form on how to pass"
+    " `controlnet_id` or `t2i_adapter_id`.",
+    article=ARTICLE,
     fn=get_component_wise_memory,
     inputs=[
         gr.components.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5"),
+        gr.components.Textbox(lines=1, label="controlnet_id", info="Example: lllyasviel/sd-controlnet-canny"),
+        gr.components.Textbox(lines=1, label="t2i_adapter_id", info="Example: TencentARC/t2iadapter_color_sd14v1"),
         gr.components.Textbox(lines=1, label="hf_token", info="Pass this in case of private repositories."),
         gr.components.Radio(
             ["fp32", "fp16", "bf16"],
@@ -144,11 +224,20 @@ with gr.Interface(
     ],
     outputs=[gr.Markdown(label="Output")],
     examples=[
-        ["runwayml/stable-diffusion-v1-5", None, "fp32", None, ".safetensors"],
-        ["stabilityai/stable-diffusion-xl-base-1.0", None, "fp16", None, ".safetensors"],
-        ["PixArt-alpha/PixArt-XL-2-1024-MS", None, "fp32", None, ".safetensors"],
-        ["stabilityai/stable-cascade", None, "bf16", None, ".safetensors"],
-        ["Deci/DeciDiffusion-v2-0", None, "fp32", None, ".safetensors"],
+        ["runwayml/stable-diffusion-v1-5", None, None, None, "fp32", None, ".safetensors"],
+        ["PixArt-alpha/PixArt-XL-2-1024-MS", None, None, None, "fp32", None, ".safetensors"],
+        ["runwayml/stable-diffusion-v1-5", "lllyasviel/sd-controlnet-canny", None, None, "fp32", None, ".safetensors"],
+        [
+            "stabilityai/stable-diffusion-xl-base-1.0",
+            None,
+            "TencentARC/t2i-adapter-lineart-sdxl-1.0,TencentARC/t2i-adapter-canny-sdxl-1.0",
+            None,
+            "fp32",
+            None,
+            ".safetensors",
+        ],
+        ["stabilityai/stable-cascade", None, None, None, "bf16", None, ".safetensors"],
+        ["Deci/DeciDiffusion-v2-0", None, None, None, "fp32", None, ".safetensors"],
     ],
     theme=gr.themes.Soft(),
     allow_flagging="never",
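As a quick illustration of the comma-separated convention the new ARTICLE text describes, here is a minimal sketch that resolves each ControlNet repository id to the size of its first matching checkpoint, mirroring `get_individual_model_memory` from the diff above. The helper name `checkpoint_size` is hypothetical; the repo ids are the ones quoted in the ARTICLE, and the sketch assumes `huggingface_hub` is installed and the repositories are publicly reachable.

```python
from huggingface_hub import model_info


def checkpoint_size(repo_id, variant=None, extension=".safetensors"):
    """Size (in bytes) of the first checkpoint in `repo_id` matching `extension` (and `variant`, if given)."""
    # `files_metadata=True` is what populates the per-file `size` attribute.
    for sibling in model_info(repo_id, files_metadata=True).siblings:
        if extension in sibling.rfilename:
            if variant is None or variant in sibling.rfilename:
                return sibling.size
    return None  # No matching checkpoint was found.


# The same comma-separated string a user would type into the `controlnet_id` textbox.
controlnet_id = "thibaud/controlnet-openpose-sdxl-1.0,diffusers/controlnet-canny-sdxl-1.0"
sizes = {repo: checkpoint_size(repo) for repo in controlnet_id.split(",")}
for repo, size in sizes.items():
    print(f"{repo}: {size} bytes")
```

Each printed line corresponds to one entry of the new `## ControlNet(s)` section in the Space's output.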
 
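Similarly, a minimal sketch of exercising the updated `get_component_wise_memory` outside the Gradio UI, with the same arguments as the new ControlNet example row. It assumes the commit's functions live in `app.py`, that importing it does not auto-launch the demo in your environment, and that the Hub is reachable; `variant="fp32"` is mapped to `None` inside the function.

```python
# Assumption: importing app.py only defines the functions in your environment; if it also
# launches the Gradio demo, paste the functions into a REPL instead.
from app import get_component_wise_memory

# Mirrors the ["runwayml/stable-diffusion-v1-5", "lllyasviel/sd-controlnet-canny", ...] example row.
report = get_component_wise_memory(
    "runwayml/stable-diffusion-v1-5",
    controlnet_id="lllyasviel/sd-controlnet-canny",
    t2i_adapter_id=None,
    token=None,
    variant="fp32",  # treated as "no variant" by the function
    revision=None,
    extension=".safetensors",
)
# The returned Markdown now carries a "## ControlNet(s)" section in addition to the
# per-component breakdown of the base pipeline.
print(report)
```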