open_dutch_llm_leaderboard / generate_overview_json.py
Bram Vanroy
add data collection script
78588de
raw
history blame
1.43 kB
from pathlib import Path
import json
from pprint import pprint
from transformers import AutoModelForCausalLM
def get_num_parameters(model_name: str) -> int:
    """Return the parameter count of the causal LM identified by *model_name*.

    The model is instantiated through ``AutoModelForCausalLM.from_pretrained``
    and the result of its ``num_parameters()`` method is returned.
    """
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model.num_parameters()
def _parse_model_args(raw: str) -> dict:
    """Split a comma-separated ``key=value`` string into a dict of strings.

    Only the first ``=`` of each item separates key from value, so values that
    themselves contain ``=`` are kept intact. Items without any ``=`` are
    ignored, and surrounding whitespace is stripped from keys and values.
    """
    parsed = {}
    for item in raw.split(","):
        key, separator, value = item.partition("=")
        if not separator:
            # Not a key=value pair (e.g. an empty item from a trailing comma).
            continue
        parsed[key.strip()] = value.strip()
    return parsed


def _quantization_mode(model_args: dict):
    """Return ``"8-bit"``, ``"4-bit"`` or ``None`` for the parsed model args.

    A ``load_in_8bit``/``load_in_4bit`` flag counts as enabled unless its value
    is explicitly switched off (``false``, ``0`` or empty, case-insensitive).
    8-bit takes precedence over 4-bit when both are enabled.
    """
    disabled_values = ("false", "0", "")

    def is_enabled(key: str) -> bool:
        return key in model_args and model_args[key].lower() not in disabled_values

    if is_enabled("load_in_8bit"):
        return "8-bit"
    if is_enabled("load_in_4bit"):
        return "4-bit"
    return None


def main():
    """Collect per-model metadata from the evaluation files and print it.

    Every ``*.json`` file below the ``evals`` directory next to this script is
    read, except files whose stem is ``models``. The third underscore-separated
    part of the file stem is used as the model's short name. For files that
    carry a ``config`` with a ``model_args`` string containing a ``pretrained``
    entry, the model name, compute dtype, quantization mode and number of
    parameters are stored under that short name. The collected mapping is
    pretty-printed to stdout.
    """
    results = {}
    # Several result files can refer to the same model; loading a model just to
    # count its parameters is expensive, so remember the count per model name.
    num_parameters_by_model = {}

    evals_dir = Path(__file__).parent.joinpath("evals")
    for pfin in evals_dir.rglob("*.json"):
        if pfin.stem == "models":
            continue

        stem_parts = pfin.stem.split("_")
        if len(stem_parts) < 3:
            # File name does not contain a short name in third position;
            # skip it like any other file that cannot be interpreted.
            continue
        short_name = stem_parts[2]

        # The entry is created before the checks below, so a model whose files
        # lack usable config still shows up (with an empty dict).
        entry = results.setdefault(short_name, {})

        data = json.loads(pfin.read_text(encoding="utf-8"))
        if "config" not in data:
            continue

        config = data["config"]
        if "model_args" not in config:
            continue

        model_args = _parse_model_args(config["model_args"])
        if "pretrained" not in model_args:
            continue

        model_name = model_args["pretrained"]
        entry["model_name"] = model_name
        entry["compute_dtype"] = model_args.get("dtype", None)
        entry["quantization"] = _quantization_mode(model_args)

        if model_name not in num_parameters_by_model:
            num_parameters_by_model[model_name] = get_num_parameters(model_name)
        entry["num_parameters"] = num_parameters_by_model[model_name]

    pprint(results)
# Run the collection only when this file is executed as a script.
if __name__ == "__main__":
    main()