Spaces:
Build error
Build error
File size: 4,264 Bytes
cb433d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import argparse
import logging
import os
import glob
import tqdm
import torch
import PIL
import cv2
import numpy as np
import torch.nn.functional as F
from torchvision import transforms
from utils import Config, Logger, CharsetMapper
def get_model(config):
    """Build the model class named by ``config.model_name`` and switch it to eval mode.

    ``config.model_name`` is a dotted path such as ``"pkg.module.ClassName"``:
    everything before the last dot is imported as a module, the final component
    is looked up on it as the class, and that class is called with ``config``.

    Returns:
        Whatever the constructed object's ``eval()`` returns.
    """
    from importlib import import_module

    # rpartition splits on the LAST dot, matching '.'.join(parts[:-1]) / parts[-1].
    module_path, _, class_name = config.model_name.rpartition('.')
    model_cls = getattr(import_module(module_path), class_name)
    instance = model_cls(config)
    logging.info(instance)
    return instance.eval()
def preprocess(img, width, height):
    """Resize ``img`` and turn it into a normalised single-image batch tensor.

    Args:
        img: anything ``np.array`` can convert (the caller passes an RGB PIL image).
        width: target width in pixels (``cv2.resize`` takes ``(width, height)``).
        height: target height in pixels.

    Returns:
        A tensor with a leading batch dimension of 1, normalised per channel
        with the ImageNet mean/std constants below.
    """
    resized = cv2.resize(np.array(img), (width, height))
    # ToTensor converts HWC -> CHW; [None] adds the batch axis.
    batch = transforms.ToTensor()(resized)[None]
    # Reshape the per-channel statistics to (C, 1, 1) so they broadcast over H and W.
    channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    channel_std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    return (batch - channel_mean) / channel_std
def postprocess(output, charset, model_eval):
    """Greedily decode one model branch's logits into text.

    Args:
        output: either a single result dict with a ``'logits'`` entry, or a
            tuple/list of such dicts, each also carrying a ``'name'`` key.
        charset: character mapper exposing ``get_text``, ``null_char`` and
            ``max_length``.
        model_eval: name of the entry to decode when ``output`` is a tuple/list.

    Returns:
        ``(pt_text, pt_scores, pt_lengths)``, each a list with one item per batch
        element:
        - ``pt_text``: decoded strings;
        - ``pt_scores``: per-position maximum softmax probability;
        - ``pt_lengths``: text length plus one for the end token, capped at
          ``charset.max_length``.

    Raises:
        ValueError: if ``output`` is a tuple/list and no entry is named ``model_eval``.
    """
    def _select_branch(last_output, branch_name):
        """Return the result dict to decode; among several matches the last one wins."""
        if not isinstance(last_output, (tuple, list)):
            return last_output
        matches = [res for res in last_output if res['name'] == branch_name]
        if not matches:
            # Previously this fell through to an UnboundLocalError.
            available = [res['name'] for res in last_output]
            raise ValueError(
                f'No model output named {branch_name!r}; available: {available}')
        return matches[-1]

    def _decode(logit):
        """Greedy decode: per-position argmax, truncated at the first null char."""
        # dim=2 softmax plus per-item argmax(dim=1) implies (batch, seq_len, classes).
        probs = F.softmax(logit, dim=2)
        pt_text, pt_scores, pt_lengths = [], [], []
        for o in probs:
            text = charset.get_text(o.argmax(dim=1), padding=False, trim=False)
            text = text.split(charset.null_char)[0]  # stop at the end token
            pt_text.append(text)
            pt_scores.append(o.max(dim=1)[0])
            pt_lengths.append(min(len(text) + 1, charset.max_length))  # +1 for end token
        return pt_text, pt_scores, pt_lengths

    selected = _select_branch(output, model_eval)
    return _decode(selected['logits'])
def load(model, file, device=None, strict=True):
    """Load weights from a ``torch.save`` file into ``model`` and return it.

    Args:
        model: object whose ``load_state_dict`` receives the weights.
        file: path to the checkpoint file.
        device: ``None`` for CPU, an int CUDA index, or any other value accepted
            by ``torch.load``'s ``map_location`` (e.g. ``'cuda:0'``).
        strict: forwarded to ``load_state_dict``.

    Returns:
        The same ``model`` object, with the weights loaded.

    Raises:
        FileNotFoundError: if ``file`` is not an existing regular file.
    """
    if device is None:
        device = 'cpu'
    elif isinstance(device, int):
        device = torch.device('cuda', device)
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not os.path.isfile(file):
        raise FileNotFoundError(f'Checkpoint file not found: {file}')
    # SECURITY NOTE: torch.load unpickles the file; only load trusted checkpoints.
    # PyTorch >= 2.6 defaults to weights_only=True, but older releases do not.
    state = torch.load(file, map_location=device)
    # A checkpoint holding exactly {'model', 'opt'} (presumably model + optimizer
    # state from training) is unwrapped to the model weights.
    if set(state.keys()) == {'model', 'opt'}:
        state = state['model']
    model.load_state_dict(state, strict=strict)
    return model
def _collect_image_paths(input_path):
    """Return the sorted regular files named by a directory path or a glob pattern."""
    input_path = os.path.expanduser(input_path)
    if os.path.isdir(input_path):
        candidates = [os.path.join(input_path, fname) for fname in os.listdir(input_path)]
    else:
        candidates = glob.glob(input_path)
    # Skip sub-directories (and other non-files) that PIL could not open.
    return sorted(p for p in candidates if os.path.isfile(p))


def main():
    """Command-line entry point: recognise text in one or more images.

    Each prediction is logged as ``"<path>: <text>"``.

    Raises:
        FileNotFoundError: if ``--input`` matches no files.
    """
    # ``import PIL`` alone does not guarantee the ``PIL.Image`` submodule is
    # loaded, so import it explicitly rather than relying on a side effect of
    # another package's imports.
    from PIL import Image

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/train_abinet.yaml',
                        help='path to config file')
    parser.add_argument('--input', type=str, default='figs/test')
    parser.add_argument('--cuda', type=int, default=-1)
    parser.add_argument('--checkpoint', type=str, default='workdir/train-abinet/best-train-abinet.pth')
    parser.add_argument('--model_eval', type=str, default='alignment',
                        choices=['alignment', 'vision', 'language'])
    args = parser.parse_args()

    config = Config(args.config)
    if args.checkpoint is not None:
        config.model_checkpoint = args.checkpoint
    if args.model_eval is not None:
        config.model_eval = args.model_eval
    config.global_phase = 'test'
    # Cleared before get_model(config); presumably this stops the model from
    # loading separate vision/language checkpoints, since the full checkpoint
    # is loaded below.
    config.model_vision_checkpoint, config.model_language_checkpoint = None, None
    device = 'cpu' if args.cuda < 0 else f'cuda:{args.cuda}'
    Logger.init(config.global_workdir, config.global_name, config.global_phase)
    Logger.enable_file()
    logging.info(config)

    logging.info('Construct model.')
    model = get_model(config).to(device)
    model = load(model, config.model_checkpoint, device=device)
    charset = CharsetMapper(filename=config.dataset_charset_path,
                            max_length=config.dataset_max_length + 1)

    paths = _collect_image_paths(args.input)
    if not paths:
        raise FileNotFoundError("The input path(s) was not found")

    # Inference only: disable autograd so no graph is recorded per image.
    with torch.no_grad():
        for path in tqdm.tqdm(paths):
            # Close the file handle promptly; convert() returns a new image.
            with Image.open(path) as pil_img:
                img = pil_img.convert('RGB')
            img = preprocess(img, config.dataset_image_width, config.dataset_image_height)
            img = img.to(device)
            res = model(img)
            pt_text, _, __ = postprocess(res, charset, config.model_eval)
            logging.info(f'{path}: {pt_text[0]}')
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not on import.
    main()
|