import json

import gradio as gr
from huggingface_hub import hf_hub_download, model_info


def format_size(num: int) -> str:
    """Format a size in bytes into a human-readable string.

    Taken from https://stackoverflow.com/a/1094933
    """
    num_f = float(num)
    for unit in ["", "K", "M", "G", "T", "P", "E", "Z"]:
        if abs(num_f) < 1000.0:
            return f"{num_f:3.1f}{unit}"
        num_f /= 1000.0
    return f"{num_f:.1f}Y"

def format_output(memory_mapping):
    """Render a {component: size_in_bytes} mapping as a Markdown bullet list."""
    markdown_str = ""
    if memory_mapping:
        # Iterate over .items() so that each entry unpacks into (name, size).
        for component, memory in memory_mapping.items():
            markdown_str += f"* {component}: {format_size(memory)}\n"
    return markdown_str

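# Illustrative example: format_output({"unet": 3_438_167_896}) returns
# "* unet: 3.4G\n", ready to be rendered as Markdown.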
def load_model_index(pipeline_id, token=None, revision=None):
    """Download and parse the pipeline's `model_index.json`."""
    index_path = hf_hub_download(repo_id=pipeline_id, filename="model_index.json", revision=revision, token=token)
    with open(index_path, "r") as f:
        index_dict = json.load(f)
    return index_dict

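# `model_index.json` maps each pipeline folder (e.g. "unet", "vae",
# "text_encoder") to the library and class that implements it; only folders
# listed there are counted below.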
def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=None, extension=".safetensors"):
    """Report the on-disk size of each component of a Diffusers pipeline repo."""
    # Normalize the optional inputs coming from the Gradio UI.
    if token == "":
        token = None
    if revision == "":
        revision = None
    if variant == "fp32":
        # fp32 weights carry no variant tag in their filenames on the Hub.
        variant = None

    print(f"pipeline_id: {pipeline_id}, variant: {variant}, revision: {revision}, extension: {extension}")

    # Fetch per-file metadata (including sizes) and the pipeline's component index.
    files_in_repo = model_info(pipeline_id, revision=revision, token=token, files_metadata=True).siblings
    index_dict = load_model_index(pipeline_id, token=token, revision=revision)

    # A "*.index.json" manifest under text_encoder/ signals a sharded checkpoint.
    is_text_encoder_sharded = any(
        ".index.json" in file_obj.rfilename and "text_encoder" in file_obj.rfilename for file_obj in files_in_repo
    )
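    # Sharded checkpoints on the Hub ship numbered shard files (e.g.
    # "model-00001-of-00002.safetensors") alongside the "*.index.json"
    # manifest checked above.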
    component_wise_memory = {}

    # Handle the text encoder separately when it is sharded: sum the sizes of
    # all shards that match the requested variant.
    if is_text_encoder_sharded:
        for current_file in files_in_repo:
            if "text_encoder" in current_file.rfilename:
                if not current_file.rfilename.endswith(".json") and current_file.rfilename.endswith(extension):
                    # Count the shard only when it matches the requested
                    # variant (or when no variant was requested).
                    if variant is None or variant in current_file.rfilename:
                        component_wise_memory["text_encoder"] = (
                            component_wise_memory.get("text_encoder", 0) + current_file.size
                        )
        print(component_wise_memory)
    # Handle the remaining pipeline components. Auxiliary folders (and the
    # text encoder, when it was already counted above) are filtered out.
    component_filter = ["scheduler", "feature_extractor", "safety_checker", "tokenizer"]
    if is_text_encoder_sharded:
        component_filter.append("text_encoder")

    for current_file in files_in_repo:
        if all(substring not in current_file.rfilename for substring in component_filter):
            # Only consider files nested exactly one level deep, i.e. "<component>/<file>".
            is_folder = len(current_file.rfilename.split("/")) == 2
            if is_folder and current_file.rfilename.split("/")[0] in index_dict:
                selected_file = None
                if not current_file.rfilename.endswith(".json") and current_file.rfilename.endswith(extension):
                    component = current_file.rfilename.split("/")[0]
                    if (
                        variant is not None
                        and variant in current_file.rfilename
                        and "ema" not in current_file.rfilename
                    ):
                        selected_file = current_file
                    elif (
                        variant is None
                        and "ema" not in current_file.rfilename
                        and "fp16" not in current_file.rfilename
                    ):
                        # With no variant requested, skip EMA weights and
                        # fp16-tagged files ("fp16" is the only variant the UI exposes).
                        selected_file = current_file

                    if selected_file is not None:
                        print(selected_file.rfilename)
                        component_wise_memory[component] = selected_file.size

    return format_output(component_wise_memory)

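# Illustrative call (sizes are approximate and depend on the repo contents):
#   get_component_wise_memory("runwayml/stable-diffusion-v1-5", variant="fp16")
#   -> "* text_encoder: 246.1M\n* unet: 1.7G\n* vae: 167.3M\n"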
gr.Interface(
    title="Compute component-wise memory of a 🧨 Diffusers pipeline.",
    description="Sizes are reported in human-readable units (K/M/G, base 1000).",
    fn=get_component_wise_memory,
    inputs=[
        gr.components.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5"),
        gr.components.Textbox(lines=1, label="hf_token", info="Pass this in case of private repositories."),
        gr.components.Dropdown(
            ["fp32", "fp16"],
            label="variant",
            info="Precision to use for calculation.",
        ),
        gr.components.Textbox(lines=1, label="revision", info="Repository revision to use."),
        gr.components.Dropdown(
            [".bin", ".safetensors"],
            label="extension",
            info="Extension to use.",
        ),
    ],
    outputs="markdown",
    examples=[
        ["runwayml/stable-diffusion-v1-5", None, "fp32", None, ".safetensors"],
        ["stabilityai/stable-diffusion-xl-base-1.0", None, "fp16", None, ".safetensors"],
    ],
    theme=gr.themes.Soft(),
    allow_flagging="never",
).launch()
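# To try it locally: `python app.py`, then open the local URL Gradio prints.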