|
import json |
|
from pathlib import Path |
|
|
|
from huggingface_hub import model_info |
|
from tqdm import tqdm |
|
from transformers import AutoModelForCausalLM |
|
|
|
|
|
def get_num_parameters(model_name: str) -> int:
    """Return the total number of parameters of ``model_name``.

    First tries the cheap path: query the Hugging Face Hub metadata via
    ``model_info`` and read the ``"total"`` entry of its safetensors info.
    If anything in that lookup raises, it falls back to instantiating the model
    with ``AutoModelForCausalLM.from_pretrained`` and counting its parameters.
    The fallback is expensive, and it runs with ``trust_remote_code=True``.

    :param model_name: name passed both to ``model_info`` and ``from_pretrained``
    :return: total parameter count
    """
    try:
        hub_info = model_info(model_name)
        total = hub_info.safetensors["total"]
    except Exception:  # deliberate best-effort: any metadata failure -> full load
        fallback_model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
        total = fallback_model.num_parameters()
    return total
|
|
|
|
|
def _short_name(pfin):
    """Return the lowercased part of ``pfin``'s stem after its second ``_``, or None if absent."""
    parts = pfin.stem.split("_", 2)
    return parts[2].lower() if len(parts) == 3 else None


def _parse_model_args(raw):
    """Parse a ``key=value,key=value`` model-args string (or pass through a dict) into a dict.

    Empty segments are ignored. Keys and values are whitespace-stripped. Only the
    first ``=`` splits a segment, so values may themselves contain ``=``. A
    segment without ``=`` maps its key to an empty string. Any other input type
    yields an empty dict.
    """
    if isinstance(raw, dict):
        return dict(raw)
    if not isinstance(raw, str):
        return {}
    parsed = {}
    for segment in raw.split(","):
        segment = segment.strip()
        if not segment:
            continue
        key, _, value = segment.partition("=")
        parsed[key.strip()] = value.strip()
    return parsed


def _flag_enabled(model_args, key):
    """Return True only if ``model_args[key]`` is a true-ish value (True, "true", "1", "yes")."""
    value = model_args.get(key)
    if value is None:
        return False
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() in {"true", "1", "yes"}


def _quantization(model_args):
    """Map the load_in_8bit / load_in_4bit model args to a label; 8-bit takes precedence."""
    if _flag_enabled(model_args, "load_in_8bit"):
        return "8-bit"
    if _flag_enabled(model_args, "load_in_4bit"):
        return "4-bit"
    return None


def main():
    """Regenerate ``evals/models.json`` (next to this script) from the evaluation result files.

    Every ``*.json`` file under ``evals/`` is scanned recursively, except files
    named ``models.json``. Each one is expected to contain a ``config`` dict
    whose ``model_args`` holds at least ``pretrained=<model>``; this layout is
    presumably lm-evaluation-harness output (TODO confirm). For each such file,
    the entry keyed by the lowercased part of the file stem after the second
    ``_`` is rewritten with the model name, dtype and quantization.

    ``num_parameters``, ``model_type`` and ``dutch_coverage`` are carried over
    from an existing ``models.json`` when present. Only ``num_parameters`` is
    ever computed here, and only if it is missing. Files that do not match the
    expected naming or structure are skipped.
    """
    evals_dir = Path(__file__).parent.joinpath("evals")
    pf_overview = evals_dir.joinpath("models.json")
    results = json.loads(pf_overview.read_text(encoding="utf-8")) if pf_overview.exists() else {}

    # Sorted so that, when several files map to the same short name, the
    # surviving entry does not depend on filesystem iteration order.
    for pfin in tqdm(sorted(evals_dir.rglob("*.json")), desc="Generating overview JSON"):
        if pfin.stem == "models":
            continue

        short_name = _short_name(pfin)
        if short_name is None:  # stem lacks the "<x>_<y>_<name>" shape
            continue

        data = json.loads(pfin.read_text(encoding="utf-8"))
        config = data.get("config") if isinstance(data, dict) else None
        if not isinstance(config, dict) or "model_args" not in config:
            continue

        model_args = _parse_model_args(config["model_args"])
        if "pretrained" not in model_args:
            continue

        pretrained = model_args["pretrained"]
        previous = results.get(short_name)
        if not isinstance(previous, dict):
            previous = {}

        results[short_name] = {
            "model_name": pretrained,
            "compute_dtype": model_args.get("dtype", None),
            "quantization": _quantization(model_args),
            # get_num_parameters may load the whole model, so reuse a stored count.
            "num_parameters": previous["num_parameters"]
            if "num_parameters" in previous
            else get_num_parameters(pretrained),
            # Never derived here (presumably curated by hand); keep stored values.
            "model_type": previous.get("model_type", "not-given"),
            "dutch_coverage": previous.get("dutch_coverage", "not-given"),
        }

    pf_overview.write_text(json.dumps(results, indent=4, sort_keys=True), encoding="utf-8")
|
|
|
|
|
if __name__ == "__main__":
    # Regenerate the overview only when executed as a script, not on import.
    main()
|
|