# utils/eval_utils.py
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
from fairseq import utils


def get_symbols_to_strip_from_output(generator):
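    """Return the set of special symbols (e.g. BOS/EOS) to strip from generator output."""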
    if hasattr(generator, "symbols_to_strip_from_output"):
        return generator.symbols_to_strip_from_output
    else:
        return {generator.bos, generator.eos}


def decode_fn(x, tgt_dict, bpe, generator, tokenizer=None):
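    """Convert a tensor of token ids into a plain string, dropping pad/BOS/EOS
    symbols and undoing BPE (and the optional tokenizer) along the way."""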
    x = tgt_dict.string(x.int().cpu(), extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator))
    if bpe is not None:
        x = bpe.decode(x)
    if tokenizer is not None:
        x = tokenizer.decode(x)
    return x


def eval_ocr(task, generator, models, sample, **kwargs):
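    """Run inference on an OCR batch and return per-image predictions plus
    per-example exact-match accuracy against the references (None if no targets)."""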
    gen_out = task.inference_step(generator, models, sample)
    hyps, refs, results = [], [], []
    for i, sample_id in enumerate(sample["id"].tolist()):
        decode_tokens = decode_fn(gen_out[i][0]["tokens"], task.tgt_dict, task.bpe, generator).strip()
        # OCR strings are compared without whitespace, so drop all spaces once.
        hyp = decode_tokens.replace(" ", "")
        hyps.append(hyp)
        # Explicit None check: the truth value of a multi-element tensor is ambiguous.
        if sample.get("target") is not None:
            refs.append(
                decode_fn(
                    utils.strip_pad(sample["target"][i], task.tgt_dict.pad()),
                    task.tgt_dict, task.bpe, generator
                )
                .strip()
                .replace(" ", "")
            )
        results.append(
            {
                "image_id": str(sample_id),
                "ocr": hyp,
            }
        )
    # Exact-match accuracy per example; None when the batch carries no references.
    if refs:
        acc = [1.0 if hyp == ref else 0.0 for hyp, ref in zip(hyps, refs)]
    else:
        acc = None
    return results, acc


def eval_step(task, generator, models, sample, **kwargs):
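    """Dispatch evaluation to the task-specific routine; only OCR is supported here."""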
    if task.cfg._name == "ocr":
        return eval_ocr(task, generator, models, sample, **kwargs)
    else:
        raise NotImplementedError(f"no eval_step implemented for task: {task.cfg._name}")
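

# Minimal usage sketch (not part of the original file): how eval_step is
# typically driven from a fairseq-style evaluation loop. `models`, `task`, and
# `cfg` are assumed to come from checkpoint loading as in OFA's evaluation
# scripts, and `progress` is assumed to be an iterator over prepared batches;
# utils.move_to_cuda is the standard fairseq helper.
#
#     generator = task.build_generator(models, cfg.generation)
#     for sample in progress:
#         sample = utils.move_to_cuda(sample) if use_cuda else sample
#         results, scores = eval_step(task, generator, models, sample)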