machine-translation / llm_toolkit /translation_utils.py
dh-mc's picture
fixed bug for Mistral
12a5ff3
raw
history blame
11.3 kB
import os
import re
import pandas as pd
import evaluate
import seaborn as sns
import matplotlib.pyplot as plt
from datasets import load_dataset
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from tqdm import tqdm
from eval_modules.calc_repetitions import *
from llm_toolkit.llm_utils import load_tokenizer
# Log which module file is being loaded; this runs once, at import time.
print(f"loading {__file__}")
# Metric objects are loaded once at import time via `evaluate.load` and shared
# by the functions below as module-level globals.
bleu = evaluate.load("bleu")
rouge = evaluate.load("rouge")
meteor = evaluate.load("meteor")
# NOTE(review): none of the functions visible in this file read this
# module-level `accuracy` object; calc_metrics assigns a local variable of the
# same name instead. Confirm whether it is still needed.
accuracy = evaluate.load("accuracy")
def extract_answer(text, debug=False):
    """Strip chat-template scaffolding from a raw model generation.

    Three regex substitutions are applied in order, each replacing its match
    with the empty string:

    1. everything up to and including an ``assistant`` or ``[/INST]`` marker,
       plus the characters that follow it up to the next word boundary;
    2. the first ``<...>`` tag and everything after it;
    3. everything up to and including ``end_header_id|>`` followed by two
       newlines.

    Args:
        text: the raw generated text. Falsy values (``None``, ``""``) are
            returned unchanged.
        debug: when true, print the intermediate text after each step.

    Returns:
        The cleaned text, or ``text`` itself when it is falsy.
    """
    if not text:
        return text

    regex_flags = re.DOTALL | re.MULTILINE

    # (pattern to delete, label printed in debug mode), applied in this order.
    stages = (
        (r".*?(assistant|\[/INST\]).+?\b", "--------\nstep 1:"),
        (r"<.+?>.*", "--------\nstep 2:"),
        (r".*?end_header_id\|>\n\n", "--------\nstep 3:"),
    )

    for pattern, label in stages:
        text = re.sub(pattern, "", text, flags=regex_flags)
        if debug:
            print(label, text)

    return text
def calc_metrics(references, predictions, debug=False):
    """Score predictions against references with METEOR, BLEU, ROUGE and exact match.

    Each prediction is first passed through ``extract_answer``; the metrics
    are then computed with the module-level ``meteor``, ``bleu`` and ``rouge``
    objects.

    Args:
        references: the expected translations.
        predictions: the raw model outputs, same length as ``references``.
        debug: when true, also report the indices of exact matches.

    Returns:
        A dict with the keys ``meteor`` (the "meteor" entry of the METEOR
        result), ``bleu_scores`` and ``rouge_scores`` (the full results
        returned by the respective ``compute`` calls), ``accuracy`` (fraction
        of predictions equal to their reference) and, if ``debug`` is set,
        ``correct_ids``.

    Raises:
        AssertionError: if the two inputs differ in length.
    """
    assert len(references) == len(
        predictions
    ), f"lengths are difference: {len(references)} != {len(predictions)}"

    cleaned = [extract_answer(raw) for raw in predictions]

    meteor_score = meteor.compute(predictions=cleaned, references=references)[
        "meteor"
    ]
    bleu_result = bleu.compute(
        predictions=cleaned, references=references, max_order=4
    )
    rouge_result = rouge.compute(predictions=cleaned, references=references)

    # 1 for an exact string match with the reference, 0 otherwise.
    hits = [int(ref == pred) for ref, pred in zip(references, cleaned)]

    results = {
        "meteor": meteor_score,
        "bleu_scores": bleu_result,
        "rouge_scores": rouge_result,
        "accuracy": sum(hits) / len(references),
    }

    if debug:
        results["correct_ids"] = [idx for idx, hit in enumerate(hits) if hit == 1]

    return results
def save_results(model_name, results_path, dataset, predictions, debug=False):
    """Add one column of predictions to the results CSV, creating it if needed.

    If ``results_path`` does not exist yet, the file is seeded from
    ``dataset.to_pandas()`` with the ``text`` and ``prompt`` columns removed
    when present. Otherwise the existing CSV is read back. In both cases
    ``predictions`` is stored under the column ``model_name`` and the frame is
    written to ``results_path`` without the index.

    Args:
        model_name: name of the column that receives ``predictions``.
        results_path: path of the CSV file to create or update.
        dataset: object exposing ``to_pandas()``; only used when the CSV does
            not exist yet.
        predictions: one prediction per row of the results table.
        debug: when true, print the first row of the resulting frame.
    """
    if not os.path.exists(results_path):
        # A bare file name has an empty directory part, and os.makedirs("")
        # raises FileNotFoundError, so only create directories when there are any.
        dir_path = os.path.dirname(results_path)
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)

        df = dataset.to_pandas()
        # "text"/"prompt" are only added when the dataset was built with a
        # tokenizer; tolerate their absence instead of raising KeyError.
        df.drop(columns=["text", "prompt"], inplace=True, errors="ignore")
    else:
        df = pd.read_csv(results_path, on_bad_lines="warn")

    df[model_name] = predictions

    if debug:
        print(df.head(1))

    df.to_csv(results_path, index=False)
def load_translation_dataset(data_path, tokenizer=None):
    """Load the Chinese-to-English train/test splits, generating them on first use.

    The split files live next to ``data_path`` with ``-train``/``-test``
    inserted before the ``.tsv`` extension. If the train file is missing, the
    TSV at ``data_path`` is loaded, rows with an empty ``chinese`` or
    ``english`` value are dropped, and an 80/20 train/test split is written to
    those two files. The two files are then loaded as a dataset with "train"
    and "test" splits.

    When ``tokenizer`` is given, every example additionally gets:

    * ``prompt``: the chat-template rendering of a system message plus the
      translation request, with the generation prompt appended;
    * ``text``: ``prompt`` followed by the reference translation and
      ``tokenizer.eos_token``.

    Args:
        data_path: path to the source TSV with ``chinese`` and ``english``
            columns.
        tokenizer: optional tokenizer providing ``apply_chat_template`` and
            ``eos_token``.

    Returns:
        The dataset object returned by ``load_dataset`` (after ``map`` when a
        tokenizer is given).
    """
    # Rewrite only the last ".tsv" so that a ".tsv" occurring earlier in the
    # path (for example in a directory name) is left untouched.
    train_data_file = "-train.tsv".join(data_path.rsplit(".tsv", 1))
    test_data_file = "-test.tsv".join(data_path.rsplit(".tsv", 1))

    if not os.path.exists(train_data_file):
        print("generating train/test data files")
        dataset = load_dataset(
            "csv", data_files=data_path, delimiter="\t", split="train"
        )
        print(len(dataset))  # row count before filtering
        dataset = dataset.filter(lambda x: x["chinese"] and x["english"])

        datasets = dataset.train_test_split(test_size=0.2)
        print(len(dataset))  # row count after filtering

        # Convert to pandas DataFrame
        train_df = pd.DataFrame(datasets["train"])
        test_df = pd.DataFrame(datasets["test"])

        # Save to TSV
        train_df.to_csv(train_data_file, sep="\t", index=False)
        test_df.to_csv(test_data_file, sep="\t", index=False)

    print("loading train/test data files")
    datasets = load_dataset(
        "csv",
        data_files={"train": train_data_file, "test": test_data_file},
        delimiter="\t",
    )

    if tokenizer:
        translation_prompt = "Please translate the following Chinese text into English and provide only the translated content, nothing else.\n{}"
        system_message = {
            "role": "system",
            "content": "You are an expert in translating Chinese to English.",
        }

        def formatting_prompts_func(examples):
            """Render the prompt and the full training text for a batch."""
            texts = []
            prompts = []
            for source_text, target_text in zip(
                examples["chinese"], examples["english"]
            ):
                messages = [
                    system_message,
                    {
                        "role": "user",
                        "content": translation_prompt.format(source_text),
                    },
                ]
                prompt = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
                prompts.append(prompt)
                texts.append(prompt + target_text + tokenizer.eos_token)
            return {"text": texts, "prompt": prompts}

        datasets = datasets.map(
            formatting_prompts_func,
            batched=True,
        )

    print(datasets)
    return datasets
def get_metrics(df, max_output_tokens=2048):
    """Summarise translation quality and repetition statistics per model column.

    Every column of ``df`` from the third one onwards is treated as the
    outputs of one model run and is scored against ``df["english"]`` with
    ``calc_metrics``. Column names are split on ``"/rpp-"``: the part before
    it is the model name (also passed to ``load_tokenizer``), and whatever
    follows the last ``"rpp-"`` is stored in the ``rpp`` column.

    Note that ``df`` is modified in place: the helper columns ``ews_score``,
    ``repetition_score``, ``total_repetitions`` and ``output_tokens`` are
    overwritten for each model column, so after the call they hold the values
    for the last model column processed.

    Args:
        df: results frame with an ``english`` reference column; the first two
            columns are skipped and all remaining ones are scored.
        max_output_tokens: token count used to count outputs whose tokenized
            length equals exactly this value.

    Returns:
        A DataFrame with one row per scored column and the columns ``model``,
        ``rpp``, ``meteor``, ``bleu_1``, ``rouge_l``, ``ews_score``,
        ``repetition_score``, ``total_repetitions`` and
        ``num_entries_with_max_output_tokens``.
    """
    # Captured before the loop adds helper columns to ``df``.
    prediction_columns = df.columns[2:]

    metrics_df = pd.DataFrame(df.columns.T)[2:]
    metrics_df.rename(columns={0: "model"}, inplace=True)
    metrics_df["rpp"] = metrics_df["model"].apply(lambda x: x.split("rpp-")[-1])
    metrics_df["model"] = metrics_df["model"].apply(lambda x: x.split("/rpp-")[0])
    metrics_df.reset_index(inplace=True)
    metrics_df = metrics_df.drop(columns=["index"])

    tokenizers = {
        model: load_tokenizer(model) for model in metrics_df["model"].unique()
    }

    meteor_scores = []
    bleu_scores = []
    rouge_l_scores = []
    ews_scores = []
    repetition_scores = []
    total_repetition_counts = []
    num_entries_with_max_output_tokens = []

    for col in prediction_columns:
        metrics = calc_metrics(df["english"], df[col], debug=True)
        print(f"{col}: {metrics}")

        meteor_scores.append(metrics["meteor"])
        # Stored under "bleu_1", but this is the "bleu" value that
        # calc_metrics reports.
        bleu_scores.append(metrics["bleu_scores"]["bleu"])
        rouge_l_scores.append(metrics["rouge_scores"]["rougeL"])

        df[["ews_score", "repetition_score", "total_repetitions"]] = df[col].apply(
            detect_scores
        )
        ews_scores.append(df["ews_score"].mean())
        repetition_scores.append(df["repetition_score"].mean())
        total_repetition_counts.append(df["total_repetitions"].mean())

        # Use the same "/rpp-" split that keyed ``tokenizers`` so the lookup
        # cannot miss, and resolve the tokenizer once per column.
        tokenizer = tokenizers[col.split("/rpp-")[0]]
        df["output_tokens"] = df[col].apply(
            lambda x: len(tokenizer(x)["input_ids"])
        )
        num_entries_with_max_output_tokens.append(
            df["output_tokens"].value_counts().get(max_output_tokens, 0)
        )

    metrics_df["meteor"] = meteor_scores
    metrics_df["bleu_1"] = bleu_scores
    metrics_df["rouge_l"] = rouge_l_scores
    metrics_df["ews_score"] = ews_scores
    # Each column gets its own list; previously all three repetition columns
    # were filled from the ews_score list.
    metrics_df["repetition_score"] = repetition_scores
    metrics_df["total_repetitions"] = total_repetition_counts
    metrics_df["num_entries_with_max_output_tokens"] = (
        num_entries_with_max_output_tokens
    )

    return metrics_df
def plot_metrics(metrics_df, figsize=(14, 5), ylim=(0, 0.44)):
    """Show a grouped bar chart of the meteor, bleu_1 and rouge_l columns per model.

    The metrics frame is melted to long form and drawn with seaborn, one hue
    per model. Each model gets its own hatch pattern (cycled when there are
    more models than patterns), the legend handles are hatched to match, and
    every bar with non-zero height is labelled with its value to two decimals.

    Args:
        metrics_df: frame with a ``model`` column and the three score columns.
        figsize: size of the matplotlib figure.
        ylim: y-axis limits.
    """
    plt.figure(figsize=figsize)

    long_df = pd.melt(
        metrics_df, id_vars="model", value_vars=["meteor", "bleu_1", "rouge_l"]
    )
    ax = sns.barplot(x="variable", y="value", hue="model", data=long_df)

    # One hatch per model, reusing patterns once the list is exhausted.
    hatch_cycle = ["/", "\\", "|", "-", "+", "x", "o", "O", ".", "*", "//", "\\\\"]
    model_names = metrics_df["model"].unique()
    hatch_for = {}
    for position, name in enumerate(model_names):
        hatch_for[name] = hatch_cycle[position % len(hatch_cycle)]

    # Bars are mapped back to a model by integer-dividing the patch index by
    # the number of metrics and reading that row of the melted model column.
    bars_per_model = len(long_df["variable"].unique())
    for position, patch in enumerate(ax.patches):
        owner = long_df["model"].iloc[position // bars_per_model]
        patch.set_hatch(hatch_for[owner])

    # Keep the legend swatches in sync with the bar hatches.
    legend_handles, _ = ax.get_legend_handles_labels()
    for handle, name in zip(legend_handles, model_names):
        handle.set_hatch(hatch_for[name])

    ax.set_xticklabels(["METEOR", "BLEU-1", "ROUGE-L"])

    for patch in ax.patches:
        bar_height = patch.get_height()
        if bar_height != 0:
            ax.annotate(
                f"{bar_height:.2f}",
                (patch.get_x() + patch.get_width() / 2.0, bar_height),
                ha="center",
                va="center",
                xytext=(0, 10),
                textcoords="offset points",
            )

    ax.set(ylim=ylim, ylabel="Scores", xlabel="Metrics")
    plt.legend(bbox_to_anchor=(0.5, -0.1), loc="upper center")
    plt.show()
def plot_times(perf_df, ylim=0.421):
    """Show stacked train/eval time bars per model, with an optional METEOR line.

    Train time is drawn first (red) and eval time is stacked on top of it
    (orange). When ``perf_df`` has a ``meteor`` column, it is plotted as a
    line on a secondary y-axis whose upper limit is ``ylim``. Every bar with
    non-zero height is labelled with its value to two decimals.

    Args:
        perf_df: frame with ``model``, ``train-time(mins)`` and
            ``eval-time(mins)`` columns, and optionally ``meteor``.
        ylim: upper limit of the secondary (METEOR) axis.
    """
    fig, time_ax = plt.subplots(figsize=(12, 10))

    time_ax.set_xlabel("Models")
    time_ax.set_ylabel("Time (mins)")
    model_names = perf_df["model"]
    time_ax.set_xticks(range(len(model_names)))  # fix tick positions before labelling
    time_ax.set_xticklabels(model_names, rotation=90)

    # Bottom layer: training time.
    time_ax.bar(
        model_names,
        perf_df["train-time(mins)"],
        color="tab:red",
        label="train-time",
    )
    # Top layer: evaluation time, offset by the training time beneath it.
    time_ax.bar(
        model_names,
        perf_df["eval-time(mins)"],
        bottom=perf_df["train-time(mins)"],
        color="orange",
        label="eval-time",
    )

    time_ax.tick_params(axis="y")
    time_ax.legend(loc="upper left")

    if "meteor" in perf_df.columns:
        meteor_color = "tab:blue"
        score_ax = time_ax.twinx()
        score_ax.set_ylabel("METEOR", color=meteor_color)
        score_ax.plot(
            perf_df["model"],
            perf_df["meteor"],
            color=meteor_color,
            marker="o",
            label="meteor",
        )
        score_ax.tick_params(axis="y", labelcolor=meteor_color)
        score_ax.legend(loc="upper right")
        # Keep the automatic lower bound, cap the upper bound at ``ylim``.
        lower_bound = score_ax.get_ylim()[0]
        score_ax.set_ylim(lower_bound, ylim)

    # Write each bar's value just below its top edge; empty bars are skipped.
    for patch in time_ax.patches:
        bar_height = patch.get_height()
        if bar_height != 0:
            time_ax.annotate(
                f"{bar_height:.2f}",
                (patch.get_x() + patch.get_width() / 2.0, patch.get_y() + bar_height),
                ha="center",
                va="center",
                xytext=(0, -10),
                textcoords="offset points",
            )

    fig.tight_layout()
    plt.show()
def translate_via_llm(text):
    """Translate Chinese ``text`` into English through a chat-completions endpoint.

    The endpoint is taken from the ``OPENAI_BASE_URL`` environment variable,
    falling back to ``http://localhost:8000/v1`` when it is unset or empty.
    A new ``ChatOpenAI`` client (model ``gpt-4o``, temperature 0, up to two
    retries) and a single-message prompt are built on every call.

    Args:
        text: the Chinese text to translate.

    Returns:
        The ``content`` of the model's response.
    """
    endpoint = os.getenv("OPENAI_BASE_URL") or "http://localhost:8000/v1"

    client = ChatOpenAI(
        model="gpt-4o",
        temperature=0,
        max_tokens=None,
        timeout=None,
        max_retries=2,
        base_url=endpoint,
    )

    human_message = (
        "human",
        "Please translate the following Chinese text into English and provide only the translated content, nothing else.\n{input}",
    )
    template = ChatPromptTemplate.from_messages([human_message])

    # Pipe the rendered prompt into the model and return the reply text.
    reply = (template | client).invoke({"input": text})
    return reply.content
def translate(text, cache_dict):
    """Translate ``text``, reusing a previous result from ``cache_dict`` if present.

    Args:
        text: the text to translate; also used as the cache key.
        cache_dict: mapping from source text to translation. A newly obtained
            translation is stored in it before returning.

    Returns:
        The cached translation, or the fresh result of ``translate_via_llm``.
    """
    if text not in cache_dict:
        # Cache miss: call the model once and remember the answer.
        fresh_translation = translate_via_llm(text)
        cache_dict[text] = fresh_translation
        return fresh_translation

    return cache_dict[text]