# Tzktz's picture
# Upload 7664 files
# 6fc683c verified
# ------------------------------------------
# TextDiffuser: Diffusion Models as Text Painters
# Paper Link: https://arxiv.org/abs/2305.10855
# Code Link: https://github.com/microsoft/unilm/tree/master/textdiffuser
# Copyright (c) Microsoft Corporation.
# This file computes OCR-based evaluation metrics (precision, recall, F-measure, accuracy) for generated images.
# ------------------------------------------
import os
import re
import copy
# Ground-truth keyword lists, keyed by MARIOEval subset name. Each subset maps
# to a list that receives one keyword list per prompt.
gts = {
    subset: []
    for subset in (
        'ChineseDrawText',
        'DrawBenchText',
        'DrawTextCreative',
        'LAIONEval4000',
        'OpenLibraryEval500',
        'TMDBEval500',
    )
}

# Running score accumulators per generation method: image count, summed
# precision / recall / accuracy, and the F-measure computed at the end.
# The comprehension builds a fresh inner dict for every method.
results = {
    method_name: {'cnt': 0, 'p': 0, 'r': 0, 'f': 0, 'acc': 0}
    for method_name in ('stablediffusion', 'textdiffuser', 'controlnet', 'deepfloyd')
}
# Ground-truth keywords are the spans enclosed in single quotes ('...').
# Compiled once at import time rather than on every call.
_QUOTED_SPAN_RE = re.compile(r"'(.*?)'")


def get_key_words(text: str):
    """Return the whitespace-separated words found inside '...' spans of *text*.

    Each quoted span is split on whitespace and the pieces are returned in
    order of appearance. Text outside quotes is ignored. An empty list is
    returned when *text* contains no quoted span.

    Caveat: every single quote acts as a delimiter, so an apostrophe inside
    the prompt (e.g. "don't") will pair with the next quote and can shift
    which text is treated as a keyword.
    """
    return [word for span in _QUOTED_SPAN_RE.findall(text) for word in span.split()]
# load gt
files = os.listdir('/path/to/MARIOEval')
for file in files:
lines = open(os.path.join('/path/to/MARIOEval', file, f'{file}.txt')).readlines()
for line in lines:
line = line.strip().lower()
gts[file].append(get_key_words(line))
print(gts['ChineseDrawText'][:10])
def get_p_r_acc(method, pred, gt):
    """Score one image's OCR output against its ground-truth keywords.

    Both word lists are stripped and lower-cased before comparison. Matching
    is one-to-one: each ground-truth word can be matched by at most one
    predicted word, i.e. the match count is the size of the multiset
    intersection of the two lists.

    Args:
        method: Name of the generation method. Not used in the computation;
            kept so existing call sites keep working.
        pred: Recognised words, e.g. the raw lines of an OCR result file.
            Surrounding whitespace and newlines are stripped.
        gt: Ground-truth keywords for the prompt.

    Returns:
        (precision, recall, acc):
        - precision = matched / len(pred);
        - recall = matched / len(gt);
        - acc is 1 when both lists contain exactly the same words with the
          same multiplicities, else 0.
        Both denominators are smoothed by 1e-8, so an empty list gives 0.0
        instead of raising ZeroDivisionError.
    """
    from collections import Counter

    pred = [word.strip().lower() for word in pred]
    gt = [word.strip().lower() for word in gt]

    # Counter '&' keeps min(count_pred, count_gt) per word, giving a
    # one-to-one match count.
    matched = sum((Counter(pred) & Counter(gt)).values())
    precision = matched / (len(pred) + 1e-8)
    recall = matched / (len(gt) + 1e-8)

    # Compare the sorted word lists themselves. Concatenating them into one
    # string would wrongly equate different splits such as ['ab', 'c'] and
    # ['a', 'bc'].
    acc = 1 if sorted(pred) == sorted(gt) else 0
    return precision, recall, acc
# Score every OCR result file and accumulate per-method totals. File names are
# split on '_' into <method>_<dataset>_<prompt_index>_<image_index>. The
# prompt index selects the ground-truth keyword list within the dataset.
OCR_RESULT_DIR = '/path/to/MaskTextSpotterV3/tools/ocr_result'
files = os.listdir(OCR_RESULT_DIR)
print(len(files))
for file in files:
    method, dataset, prompt_index, image_index = file.strip().split('_')
    # `with` closes each handle instead of leaking one per result file.
    with open(os.path.join(OCR_RESULT_DIR, file), encoding='utf-8') as f:
        ocrs = f.readlines()
    p, r, acc = get_p_r_acc(method, ocrs, gts[dataset][int(prompt_index)])
    stats = results[method]
    stats['cnt'] += 1
    stats['p'] += p
    stats['r'] += r
    stats['acc'] += acc
# Turn the per-method sums into means. F is the harmonic mean of the averaged
# precision and recall (not an average of per-image F scores).
for method, stats in results.items():
    cnt = stats['cnt']
    if cnt == 0:
        # No OCR results for this method: keep its scores at 0 instead of
        # raising ZeroDivisionError and losing every other method's report.
        continue
    stats['p'] /= cnt
    stats['r'] /= cnt
    stats['f'] = 2 * stats['p'] * stats['r'] / (stats['p'] + stats['r'] + 1e-8)
    stats['acc'] /= cnt
print(results)