# Compare lighteval MT-Bench per-sample details across models and plot/aggregate the scores.
import os
from pprint import pprint

import pandas as pd
from datasets import load_dataset

pd.options.plotting.backend = "plotly"

MODELS = [
    "mistralai__Mistral-7B-Instruct-v0.2",
    # "HuggingFaceH4__zephyr-7b-beta",
    # "meta-llama__Llama-2-7b-chat-hf",
    # "01-ai__Yi-34B-Chat",
]

HF_TOKEN = os.getenv("HF_TOKEN")

# Maps the turn index to the corresponding lighteval metric name.
score_turn = {
    1: "multi_turn",
    0: "single_turn",
}
def get_dataframe_lighteval() -> pd.DataFrame:
    """Load the lighteval MT-Bench details for each model and flatten them into one sample per turn."""
    samples = []
    for model in MODELS:
        details_lighteval = load_dataset(
            f"SaylorTwift/details_{model}_private",
            "extended_mt_bench_0",
            split="latest",
            token=HF_TOKEN,
        )
        for d in details_lighteval:
            judement_prompt = d["judement_prompt"]
            judgement = d["judgement"]
            predictions = d["predictions"][0]
            prompts = d["full_prompt"]

            turns = []
            for turn in range(len(predictions)):
                if turn == 1:
                    # The second-turn prompt template expects the first-turn response.
                    prompt = prompts[turn].format(model_response=predictions[turn - 1])
                else:
                    prompt = prompts[turn]

                turns.append([])
                turns[turn].append(prompt)
                turns[turn].append(predictions[turn])
                turns[turn].append(judement_prompt[turn])
                turns[turn].append(judgement[turn])

            for i, turn in enumerate(turns):
                samples.append(
                    {
                        "model": model,
                        "turn": i,
                        "prompt": turn[0],
                        "response": turn[1],
                        "judgement_prompt": turn[2],
                        "judgment": turn[3],
                        "score": d["metrics"][score_turn[i]],
                        "question_id": d["specifics"]["id"],
                    }
                )

    dataframe_all_samples = pd.DataFrame(samples)
    return dataframe_all_samples
def construct_dataframe() -> pd.DataFrame:
    """
    Construct a dataframe of per-question, per-turn MT-Bench samples, indexed by question id.
    """
    lighteval = get_dataframe_lighteval()
    # Keep only the model name, dropping the organization prefix (e.g. "mistralai__").
    lighteval["model"] = lighteval["model"].apply(lambda x: x.split("__")[1])
    lighteval = lighteval.set_index(["question_id", "turn", "model"])
    all_samples = lighteval.reset_index()
    all_samples = all_samples.set_index("question_id")
    return all_samples.dropna()
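# Hedged sketch, not part of the original script: create_plot() below expects
# "score_lighteval" and "score_mt_bench" columns, which construct_dataframe()
# alone does not produce (it only yields "score"). One way to obtain them is to
# merge the lighteval samples with a reference frame of original MT-Bench scores.
# The name `mt_bench_df` and its assumed columns ("question_id", "turn", "model",
# "score") are illustrative assumptions, not part of the source.
def merge_with_mt_bench(lighteval_df: pd.DataFrame, mt_bench_df: pd.DataFrame) -> pd.DataFrame:
    left = lighteval_df.reset_index().rename(columns={"score": "score_lighteval"})
    right = mt_bench_df.rename(columns={"score": "score_mt_bench"})
    merged = left.merge(right, on=["question_id", "turn", "model"], how="inner")
    return merged.set_index("question_id")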
def create_plot(model: str, dataframe: pd.DataFrame):
    # Note: this expects a dataframe that also carries reference MT-Bench scores
    # (columns "score_lighteval" and "score_mt_bench"), e.g. one merged with the
    # original MT-Bench results; construct_dataframe() alone only provides "score".
    new = dataframe[dataframe["model"] == model].dropna()
    new = new[new["turn"] == 1]
    new["score_lighteval"] = new["score_lighteval"].astype(int)
    new["score_mt_bench"] = new["score_mt_bench"].astype(int)
    new = new[["score_lighteval", "score_mt_bench"]]
    new.index = new.index.astype(str)
    fig = new.plot.bar(
        title="Scores",
        labels={"index": "Index", "value": "Score"},
        barmode="group",
    )
    return fig
def get_scores(dataframe: pd.DataFrame) -> pd.DataFrame:
    dataframe = dataframe.dropna()
    dataframe["score"] = dataframe["score"].astype(int)
    new = dataframe[["score", "turn", "model"]]
    # Average per (model, turn) first, then average the two turn scores per model.
    new = new.groupby(["model", "turn"]).mean()
    new = new.groupby(["model"]).mean()
    return new
if __name__ == "__main__":
    df = construct_dataframe()
    pprint(df)
    # print(df.iloc[130])
    # model = "zephyr-7b-beta"
    # fig = create_plot(model, df)
    # fig.show()
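    # Usage sketch (assumption, not in the original script): aggregate per-model
    # averages of the single-turn and multi-turn scores with get_scores(), e.g.:
    #
    #     print(get_scores(df))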