# Spaces:
# Running
# Running
from pycparser.ply.yacc import token | |
from ultralytics import YOLO | |
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoModelForCausalLM, pipeline, AutoModelForMaskedLM | |
from PIL import Image | |
import numpy as np | |
import pandas as pd | |
from nltk.translate import bleu_score | |
from nltk.translate.bleu_score import SmoothingFunction | |
import torch | |
# Weights file handed to the YOLO constructor below.
yolo_weights_path = "final_wts.pt"

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

# TrOCR processor + encoder/decoder model used for text recognition.
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten')
trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten').to(device)
trocr_model.config.num_beams = 1

# YOLO must follow the detected device as well; a hard-coded 'mps' here
# fails on hosts where only CUDA or only the CPU is available.
yolo_model = YOLO(yolo_weights_path).to(device)

# RoBERTa-large, both as a fill-mask pipeline and as a raw masked-LM model.
unmasker_large = pipeline('fill-mask', model='roberta-large', device=device)
roberta_model = AutoModelForMaskedLM.from_pretrained("roberta-large").to(device)

print(f'TrOCR and YOLO Models loaded on {device}')
# -------------------------------------------------------
# A recognised text is kept only when its score is strictly greater than
# this value.
CONFIDENCE_THRESHOLD = 0.72
# When the BLEU score between a text and the next one reaches this value,
# the two are treated as overlapping and only the higher-scoring one is kept.
BLEU_THRESHOLD = 0.6
def inference(image_path, debug=False, return_texts='final'):
    """Detect text regions with YOLO, read them with TrOCR and de-duplicate.

    Pipeline: crop every YOLO box (ordered by vertical centre), run TrOCR on
    the crops, score each text, strip ``'# '`` / ``' #'`` markers, drop texts
    whose score is not above ``CONFIDENCE_THRESHOLD``, compute BLEU between
    neighbouring texts and collapse runs whose BLEU reaches
    ``BLEU_THRESHOLD`` to their best-scoring member.

    Args:
        image_path: Path of the image to process.
        debug: When true, print the number of cropped regions.
        return_texts: Selects how far the pipeline runs and what is returned:
            ``'generated'``            DataFrame(text, score, y), raw output.
            ``'post_processed'``       DataFrame(text, score, y), markers stripped.
            ``'qualified'``            DataFrame of texts above the threshold
                                       (text, score, y, logits, tokens).
            ``'qualified_with_bleu'``  same, plus a ``bleu`` column.
            ``'final'``                DataFrame(text, score, y) after
                                       overlap removal.
            any other value            tuple ``(final_df, bounding_box_path,
                                       final_tokens, final_logits,
                                       generated_texts)``.

    Returns:
        A ``pandas.DataFrame`` or the 5-tuple described above.  When YOLO
        finds no boxes, an empty DataFrame (columns text/score/y) is returned
        for the DataFrame modes and ``(empty_df, bounding_box_path, [], [],
        [])`` for the tuple mode.
    """
    import os  # local import: only needed to build the annotated-image path

    def get_cropped_images(image_path):
        """Return (crops, top-y of each crop, path of YOLO's annotated image)."""
        result = yolo_model(image_path, save=True)[0]
        # Decode the image once (not once per box) and release the file handle.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        patches = []
        ys = []
        for box in sorted(result.boxes, key=lambda b: b.xywh[0][1]):
            x_center, y_center, w, h = box.xywh[0].cpu().numpy()
            x, y = x_center - w / 2, y_center - h / 2
            patches.append(image.crop((x, y, x + w, y + h)))
            ys.append(y)
        # Build "<save_dir>/<image stem>.jpg" without assuming a '/' in the
        # path or a three-character file extension.
        stem = os.path.splitext(os.path.basename(result.path))[0]
        bounding_box_path = os.path.join(str(result.save_dir), stem + '.jpg')
        return patches, ys, bounding_box_path

    def get_model_output(images):
        """Run TrOCR on the crops; return texts, stacked logits and tokens."""
        pixel_values = processor(images=images, return_tensors="pt").pixel_values.to(device)
        output = trocr_model.generate(pixel_values, return_dict_in_generate=True, output_logits=True,
                                      max_new_tokens=30)
        generated_texts = processor.batch_decode(output.sequences, skip_special_tokens=True)
        generated_tokens = [processor.tokenizer.convert_ids_to_tokens(seq) for seq in output.sequences]
        # output.logits is a per-step sequence; stack the steps along dim 1.
        stacked_logits = torch.stack(output.logits, dim=1)
        return generated_texts, stacked_logits, generated_tokens

    def get_scores(logits):
        """Mean over steps of the highest softmax probability at each step."""
        return logits.softmax(-1).max(-1).values.mean(-1)

    def post_process_texts(generated_texts):
        """Strip a leading '# ' and a trailing ' #' in place; return the list."""
        for i, text in enumerate(generated_texts):
            if len(text) > 2 and text[:2] == '# ':
                text = text[2:]
            if len(text) > 2 and text[-2:] == ' #':
                text = text[:-2]
            generated_texts[i] = text
        return generated_texts

    def get_qualified_texts(generated_texts, scores, y, logits, tokens):
        """Keep only the entries whose score exceeds CONFIDENCE_THRESHOLD."""
        return [
            {'text': text, 'score': score, 'y': y_i, 'logits': logits_i, 'tokens': tokens_i}
            for text, score, y_i, logits_i, tokens_i in zip(generated_texts, scores, y, logits, tokens)
            if score > CONFIDENCE_THRESHOLD
        ]

    def get_adjacent_bleu_scores(qualified_texts):
        """Store, on each entry, the bigram BLEU against the next entry (0 for the last)."""
        smoothing = SmoothingFunction().method1
        last_index = len(qualified_texts) - 1
        for i, entry in enumerate(qualified_texts):
            bleu = 0
            if i < last_index:
                hypothesis = entry['text'].split()
                reference = qualified_texts[i + 1]['text'].split()
                bleu = bleu_score.sentence_bleu([reference], hypothesis, weights=[0.5, 0.5],
                                                smoothing_function=smoothing)
            entry['bleu'] = bleu
        return qualified_texts

    def remove_overlapping_texts(qualified_texts):
        """Collapse each run of overlapping entries to its highest-scoring one."""
        final_texts = []
        starts_new_run = True
        for entry in qualified_texts:
            if starts_new_run:
                final_texts.append(entry)
            elif final_texts[-1]['score'] < entry['score']:
                final_texts[-1] = entry
            # A high BLEU against the next entry means the next one belongs
            # to the same run.
            starts_new_run = entry['bleu'] < BLEU_THRESHOLD
        return final_texts

    cropped_images, y, bounding_box_path = get_cropped_images(image_path)
    if debug:
        print('Number of cropped images:', len(cropped_images))

    if not cropped_images:
        # Nothing detected: an empty batch cannot be pushed through TrOCR, so
        # return empty results in the shape the caller asked for.
        empty_df = pd.DataFrame(columns=['text', 'score', 'y'])
        if return_texts in ('generated', 'post_processed', 'qualified', 'qualified_with_bleu', 'final'):
            return empty_df
        return empty_df, bounding_box_path, [], [], []

    generated_texts, logits, gen_tokens = get_model_output(cropped_images)
    # Plain Python floats: pandas cannot build a column from a tensor that
    # lives on a CUDA/MPS device.
    normalised_scores = get_scores(logits).detach().cpu().tolist()
    if return_texts == 'generated':
        return pd.DataFrame({
            'text': generated_texts,
            'score': normalised_scores,
            'y': y,
        })

    generated_texts = post_process_texts(generated_texts)
    if return_texts == 'post_processed':
        return pd.DataFrame({
            'text': generated_texts,
            'score': normalised_scores,
            'y': y,
        })

    qualified_texts = get_qualified_texts(generated_texts, normalised_scores, y, logits, gen_tokens)
    if return_texts == 'qualified':
        return pd.DataFrame(qualified_texts)

    qualified_texts = get_adjacent_bleu_scores(qualified_texts)
    if return_texts == 'qualified_with_bleu':
        return pd.DataFrame(qualified_texts)

    final_texts = remove_overlapping_texts(qualified_texts)
    final_texts_df = pd.DataFrame(final_texts, columns=['text', 'score', 'y'])
    final_tokens = [text['tokens'] for text in final_texts]
    final_logits = [text['logits'] for text in final_texts]
    if return_texts == 'final':
        return final_texts_df
    return final_texts_df, bounding_box_path, final_tokens, final_logits, generated_texts
# Sample image to run the pipeline on.
image_path = "raw_dataset/g06-037h.png"

# 'final_v2' is not one of the DataFrame-returning modes handled inside
# inference(), so the call falls through to the full 5-tuple result.
(df, bounding_path, tokens, logits, gen_texts) = inference(
    image_path,
    debug=False,
    return_texts='final_v2',
)
# ----------------------------------------------------------------
def get_new_logits(tokens):
    """Run the masked-LM over ``tokens`` and return logits in the same layout.

    All token ids are flattened into a single sequence of shape ``(1, N)``
    before being passed to ``roberta_model``; the resulting logits are
    reshaped back to ``tokens.shape + (vocab_size,)``.  The arg-max decoding
    of the new logits is printed as a side effect.

    Args:
        tokens: Integer tensor of token ids.

    Returns:
        Tensor of shape ``tokens.shape + (vocab_size,)``.
    """
    # NOTE(review): every row is concatenated into one sequence; confirm the
    # total length stays within the model's maximum sequence length.
    inputs = tokens.reshape(1, -1)
    # Get the logits from the model
    with torch.no_grad():
        # Attend to every position; the mask is built on the inputs' device.
        outputs = roberta_model(input_ids=inputs, attention_mask=torch.ones_like(inputs))
    logits = outputs.logits
    # Use this model's own vocabulary size rather than reading it from the
    # module-level `slogits` tensor.
    vocab_size = logits.shape[-1]
    logits_flattened = logits.reshape(-1, vocab_size)
    print(processor.batch_decode([logits_flattened.argmax(-1)], skip_special_tokens=True))
    return logits.reshape(tokens.shape + (vocab_size,))
# Combine the per-text logits returned by inference() into a single tensor.
slogits = torch.stack(list(logits), dim=0)
# Most likely token id at each position, and the probability assigned to it.
tokens = slogits.argmax(dim=-1)
confidence = slogits.softmax(dim=-1).max(dim=-1).values
# (row, column) coordinates of every position with confidence below 0.5.
indices = torch.where(confidence < 0.5)
# put 50264(mask) when confidence < 0.5
# NOTE(review): only row index 6 is actually masked -- this looks like a
# debugging leftover; confirm whether all rows were intended.
for i, j in zip(*indices):
    if i == 6:
        tokens[i, j] = 50264
new_logits = get_new_logits(tokens)
# ----------------------------------------------------------------
# At every low-confidence position, replace the logits with a weighted mix:
# 0.1 x the original logits plus 0.5 x the masked-LM logits.  `indices` holds
# unique (row, column) pairs, so one indexed assignment matches a per-pair loop.
slogits[indices] = slogits[indices] * 0.1 + new_logits[indices] * 0.5
# Decode the arg-max tokens of the blended logits as one flat sequence.
logits_flattened = slogits.reshape(-1, slogits.shape[-1])
processor.batch_decode([logits_flattened.argmax(-1)], skip_special_tokens=True)