import json
import os.path as osp
import random

import numpy as np
import torch
import evaluate
from rapidfuzz.distance import Levenshtein
from torchtext.data import metrics

from unimernet.common.dist_utils import main_process
from unimernet.common.registry import registry
from unimernet.tasks.base_task import BaseTask


class UniMERNet_Train(BaseTask):
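    """Training/evaluation task for UniMERNet mathematical expression recognition.

    During validation the task generates LaTeX predictions for each image batch and,
    after evaluation, reports BLEU, normalized edit distance and token accuracy,
    exposing one of them (``agg_metric``) as ``agg_metrics`` for model selection.
    """
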
    def __init__(self, temperature, do_sample, top_p, evaluate, report_metric=True, agg_metric="edit_distance"):
        super(UniMERNet_Train, self).__init__()
        self.temperature = temperature
        self.do_sample = do_sample
        self.top_p = top_p
        self.evaluate = evaluate
        self.agg_metric = agg_metric
        self.report_metric = report_metric

    @classmethod
    def setup_task(cls, cfg):
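        """Build the task from the run config, reading generation settings from ``run_cfg.generate_cfg``."""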
        run_cfg = cfg.run_cfg
        generate_cfg = run_cfg.generate_cfg
        temperature = generate_cfg.get("temperature", 0.2)
        do_sample = generate_cfg.get("do_sample", False)
        top_p = generate_cfg.get("top_p", 0.95)
        evaluate = run_cfg.evaluate
        report_metric = run_cfg.get("report_metric", True)
        agg_metric = run_cfg.get("agg_metric", "edit_distance")

        return cls(
            temperature=temperature,
            do_sample=do_sample,
            top_p=top_p,
            evaluate=evaluate,
            report_metric=report_metric,
            agg_metric=agg_metric,
        )

    def valid_step(self, model, samples):
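        """Generate predictions for one validation batch and score each sample against its ground truth."""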
        results = []
        image, text = samples["image"], samples["text_input"]

        # Autoregressive generation with the task-level sampling settings.
        preds = model.generate(
            samples,
            temperature=self.temperature,
            do_sample=self.do_sample,
            top_p=self.top_p,
        )
        pred_tokens = preds["pred_tokens"]
        pred_strs = preds["pred_str"]
        pred_ids = preds["pred_ids"]  # [b, n-1]

        # Tokenize the ground truth; drop the leading id so lengths match pred_ids.
        truth_inputs = model.tokenizer.tokenize(text)
        truth_ids = truth_inputs["input_ids"][:, 1:]
        truth_tokens = model.tokenizer.detokenize(truth_inputs["input_ids"])
        truth_strs = model.tokenizer.token2str(truth_inputs["input_ids"])
        ids = samples["id"]
        for pred_token, pred_str, pred_id, truth_token, truth_str, truth_id, id_ in zip(
            pred_tokens, pred_strs, pred_ids, truth_tokens, truth_strs, truth_ids, ids
        ):
            pred_id = pred_id.tolist()
            truth_id = truth_id.tolist()

            # Pad the shorter sequence with pad_token_id so both have equal length.
            shape_diff = len(pred_id) - len(truth_id)
            if shape_diff < 0:
                pred_id = pred_id + [model.tokenizer.pad_token_id] * (-shape_diff)
            else:
                truth_id = truth_id + [model.tokenizer.pad_token_id] * shape_diff
            pred_id, truth_id = torch.LongTensor(pred_id), torch.LongTensor(truth_id)

            # Token accuracy over positions where at least one sequence is not padding.
            mask = torch.logical_or(pred_id != model.tokenizer.pad_token_id, truth_id != model.tokenizer.pad_token_id)
            tok_acc = (pred_id == truth_id)[mask].float().mean().item()
            this_item = {
                "pred_token": pred_token,
                "pred_str": pred_str,
                "truth_str": truth_str,
                "truth_token": truth_token,
                "token_acc": tok_acc,
                "id": id_,
            }
            results.append(this_item)
        return results

    def after_evaluation(self, val_result, split_name, epoch, **kwargs):
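        """Persist per-sample results and, if requested, compute the split-level metrics."""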
        eval_result_file = self.save_result(
            result=val_result,
            result_dir=registry.get_path("result_dir"),
            filename="{}_epoch{}".format(split_name, epoch),
            remove_duplicate="id",
        )
        if self.report_metric:
            metrics = self._report_metrics(
                eval_result_file=eval_result_file, split_name=split_name
            )
        else:
            metrics = {"agg_metrics": 0.0}
        return metrics

    @main_process  # compute and log metrics on the main process only in distributed runs
    def _report_metrics(self, eval_result_file, split_name):
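        """Load the saved results, compute BLEU / edit distance / token accuracy, and append them to evaluate.txt."""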
        with open(eval_result_file) as f:
            results = json.load(f)

        edit_dists = []
        all_pred_tokens = []
        all_truth_tokens = []
        all_pred_strs = []
        all_truth_strs = []
        token_accs = []
        for result in results:
            pred_token = result["pred_token"]
            pred_str = result["pred_str"]
            truth_token = result["truth_token"]
            truth_str = result["truth_str"]
            tok_acc = result["token_acc"]
            if len(truth_str) > 0:
                norm_edit_dist = Levenshtein.normalized_distance(pred_str, truth_str)
                edit_dists.append(norm_edit_dist)
            all_pred_tokens.append(pred_token)
            all_truth_tokens.append([truth_token])
            all_pred_strs.append(pred_str)
            all_truth_strs.append(truth_str)
            token_accs.append(tok_acc)
        # bleu_score = metrics.bleu_score(all_pred_tokens, all_truth_tokens)
        # Corpus-level BLEU via the `evaluate` library; a random experiment_id keeps
        # concurrent evaluations from sharing a cache.
        bleu = evaluate.load("bleu", keep_in_memory=True, experiment_id=str(random.randint(1, int(1e8))))
        bleu_results = bleu.compute(predictions=all_pred_strs, references=all_truth_strs)
        bleu_score = bleu_results["bleu"]
        edit_distance = np.mean(edit_dists)
        token_accuracy = np.mean(token_accs)

        eval_ret = {"bleu": bleu_score, "edit_distance": edit_distance, "token_accuracy": token_accuracy}
        log_stats = {split_name: {k: v for k, v in eval_ret.items()}}
        with open(
            osp.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
        ) as f:
            f.write(json.dumps(log_stats) + "\n")

        coco_res = {k: v for k, v in eval_ret.items()}
        # agg_metrics = sum([v for v in eval_ret.values()])
        if "edit" in self.agg_metric.lower():  # edit_distance
            agg_metrics = (1 - edit_distance) * 100
        elif "bleu" in self.agg_metric.lower():  # bleu_score
            agg_metrics = bleu_score * 100
        elif "token" in self.agg_metric.lower():  # token_accuracy
            agg_metrics = token_accuracy * 100
        else:
            raise ValueError(f"Invalid metrics: '{self.agg_metric}'")
        coco_res["agg_metrics"] = agg_metrics
        return coco_res
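

# A minimal usage sketch (illustrative only, not part of the module): how the task is
# typically driven once a model and a dataloader-provided `samples` dict are available.
# The constructor values mirror the defaults read in setup_task; `model` and `samples`
# are assumed to come from the surrounding UniMERNet runner.
#
#   task = UniMERNet_Train(
#       temperature=0.2,
#       do_sample=False,
#       top_p=0.95,
#       evaluate=True,
#       report_metric=True,
#       agg_metric="edit_distance",
#   )
#   batch_results = task.valid_step(model, samples)
#   metrics = task.after_evaluation(batch_results, split_name="val", epoch=0)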