sayakpaul's picture
sayakpaul HF staff
Fetch files metadata directly from model_info call (#1)
bf03ba4 verified
raw
history blame
No virus
2.86 kB
from huggingface_hub import model_info
import gradio as gr
def bytes_to_giga_bytes(bytes):
    """Format a byte count as a size string with three decimal places.

    The divisor is 1024 ** 3, i.e. the value is in binary gibibytes (GiB),
    even though the UI labels the unit "GB".

    Args:
        bytes: Number of bytes (int or float).

    Returns:
        The size as a string such as "1.500".
    """
    gib = bytes / (1024 ** 3)
    return "{:.3f}".format(gib)
def _weight_file_tier(rfilename, variant, extension):
    """Rank one repository file as a weight file for the requested variant.

    Args:
        rfilename: Repo-relative path, e.g. "unet/diffusion_pytorch_model.fp16.safetensors".
        variant: Requested variant tag (e.g. "fp16"), or None for the untagged weights.
        extension: Weight-file extension to measure, e.g. ".safetensors" or ".bin".

    Returns:
        None if the file is not a candidate weight file of a pipeline component,
        2 if it is a preferred match for ``variant``, or 1 if it should only be
        used when its component has no preferred match.
    """
    parts = rfilename.split("/")
    # Only files sitting directly inside a top-level component folder count.
    if len(parts) != 2:
        return None
    # Components that are excluded from the report.
    if any(skip in rfilename for skip in ("scheduler", "feature_extractor", "safety_checker", "tokenizer")):
        return None
    if rfilename.endswith(".json") or not rfilename.endswith(extension):
        return None
    # Skip any file whose path mentions "ema" (EMA / "non_ema" checkpoints).
    if "ema" in rfilename:
        return None
    basename = parts[1]
    if variant is not None:
        return 2 if variant in basename else 1
    # No variant requested: prefer untagged weights. Variant files are named
    # like "<name>.<variant><ext>", so a "." left in the stem after removing
    # the extension is treated (heuristically) as a variant tag.
    stem = basename[: -len(extension)]
    return 1 if "." in stem else 2


def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=None, extension=".safetensors"):
    """Report the size of each component's weights in a Diffusers pipeline repo.

    Args:
        pipeline_id: Hub repo id, e.g. "runwayml/stable-diffusion-v1-5".
        token: Hub access token for private repos; "" or None for anonymous access.
        variant: Weight variant such as "fp16". "", None and "fp32" all select
            the untagged (default) weight files.
        revision: Branch, tag or commit to inspect; "" or None for the default.
        extension: Weight-file extension to measure (".safetensors" or ".bin");
            an empty value falls back to ".safetensors".

    Returns:
        Dict mapping component folder name (e.g. "unet") to its size, formatted
        by ``bytes_to_giga_bytes``. For each component, files matching the
        requested variant are preferred and other weight files are used only
        when no match exists. All files of the chosen tier are summed, so
        sharded checkpoints are counted in full.
    """
    # Gradio submits "" for empty text fields; the Hub client expects None.
    if token == "":
        token = None
    if revision == "":
        revision = None
    # "fp32" is the default precision, stored without a variant tag.
    if variant == "" or variant == "fp32":
        variant = None
    # An unselected dropdown yields None, which str.endswith cannot take.
    if not extension:
        extension = ".safetensors"

    print(pipeline_id, variant, revision, extension)

    files_in_repo = model_info(pipeline_id, revision=revision, token=token, files_metadata=True).siblings

    # component -> (tier, total bytes). A higher-tier file replaces lower-tier
    # ones; equal-tier files (e.g. checkpoint shards) are summed. This makes the
    # result independent of the order in which the Hub lists the files.
    best = {}
    for current_file in files_in_repo:
        tier = _weight_file_tier(current_file.rfilename, variant, extension)
        if tier is None:
            continue
        component = current_file.rfilename.split("/")[0]
        previous = best.get(component)
        if previous is None or tier > previous[0]:
            best[component] = (tier, current_file.size)
        elif tier == previous[0]:
            best[component] = (tier, previous[1] + current_file.size)

    return {component: bytes_to_giga_bytes(total) for component, (_, total) in best.items()}
# Build and launch the Gradio UI. The order of `inputs` must match the
# positional parameters of get_component_wise_memory:
# (pipeline_id, token, variant, revision, extension).
gr.Interface(
    title="Compute component-wise memory of a 🧨 Diffusers pipeline.",
    description="Sizes will be reported in GB.",
    fn=get_component_wise_memory,
    inputs=[
        gr.components.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5"),
        # Masked input so the access token is not displayed in clear text.
        gr.components.Textbox(
            lines=1,
            label="hf_token",
            info="Pass this in case of private repositories.",
            type="password",
        ),
        # Defaults mirror the function's defaults so an untouched dropdown
        # does not submit None.
        gr.components.Dropdown(
            ["fp32", "fp16"],
            value="fp32",
            label="variant",
            info="Precision to use for calculation.",
        ),
        gr.components.Textbox(lines=1, label="revision", info="Repository revision to use."),
        gr.components.Dropdown(
            [".bin", ".safetensors"],
            value=".safetensors",
            label="extension",
            info="Extension to use.",
        ),
    ],
    outputs="text",
    examples=[
        ["runwayml/stable-diffusion-v1-5", None, "fp32", None, ".safetensors"],
        ["stabilityai/stable-diffusion-xl-base-1.0", None, "fp16", None, ".safetensors"],
    ],
    theme=gr.themes.Soft(),
    # Gradio documents string modes ("never" / "auto" / "manual") for this option.
    allow_flagging="never",
).launch()