sayakpaul HF staff commited on
Commit
4b23311
1 Parent(s): d1cfed4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import hf_hub_url, get_hf_file_metadata, model_info
2
+ import gradio as gr
3
+
4
+
5
+ def bytes_to_giga_bytes(bytes):
6
+ return f"{(bytes / 1024 / 1024 / 1024):.3f}"
7
+
8
+
9
+ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=None, extension=".safetensors"):
10
+ if token == "":
11
+ token = None
12
+
13
+ if variant == "":
14
+ variant = None
15
+
16
+ if revision == "":
17
+ revision = None
18
+
19
+ if variant == "fp32":
20
+ variant = None
21
+
22
+ print(pipeline_id, variant, revision, extension)
23
+
24
+ files_in_repo = model_info(pipeline_id, revision=revision, token=token).siblings
25
+
26
+ for current_file in files_in_repo:
27
+ if all(
28
+ substring not in current_file.rfilename
29
+ for substring in ["scheduler", "feature_extractor", "safety_checker", "tokenizer"]
30
+ ):
31
+ is_folder = len(current_file.rfilename.split("/")) == 2
32
+ if is_folder:
33
+ filename = None
34
+ if not current_file.rfilename.endswith(".json") and current_file.rfilename.endswith(extension):
35
+ component = current_file.rfilename.split("/")[0]
36
+ if (
37
+ variant is not None
38
+ and variant in current_file.rfilename
39
+ and "ema" not in current_file.rfilename
40
+ ):
41
+ filename = current_file.rfilename
42
+ elif "ema" not in current_file.rfilename:
43
+ filename = current_file.rfilename
44
+
45
+ if filename is not None:
46
+ hub_url = hf_hub_url(repo_id=pipeline_id, filename=filename)
47
+ file_metadata = get_hf_file_metadata(hub_url)
48
+
49
+ component_wise_memory[component] = bytes_to_giga_bytes(file_metadata.size)
50
+
51
+ return component_wise_memory
52
+
53
+
54
+ gr.Interface(
55
+ title="Compute component-wise memory of a 🧨 Diffusers pipeline.",
56
+ description="Sizes will be reported in GB.",
57
+ fn=get_component_wise_memory,
58
+ inputs=[
59
+ gr.components.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5"),
60
+ gr.components.Textbox(lines=1, label="hf_token", info="Pass this in case of private repositories."),
61
+ gr.components.Dropdown(
62
+ [
63
+ "fp32",
64
+ "fp16",
65
+ ],
66
+ label="variant",
67
+ info="Precision to use for calculation.",
68
+ ),
69
+ gr.components.Textbox(lines=1, label="revision", info="Repository revision to use."),
70
+ gr.components.Dropdown(
71
+ [".bin", ".safetensors"],
72
+ label="extension",
73
+ info="Extension to use.",
74
+ ),
75
+ ],
76
+ outputs="text",
77
+ examples=[
78
+ ["runwayml/stable-diffusion-v1-5", None, "fp32", None, ".safetensors"],
79
+ ["stabilityai/stable-diffusion-xl-base-1.0", None, "fp16", None, ".safetensors"],
80
+ ],
81
+ theme=gr.themes.Soft(),
82
+ allow_flagging=False,
83
+ ).launch()