import os

from pycparser.ply.yacc import token  # NOTE(review): unused; looks like an accidental IDE auto-import
from ultralytics import YOLO
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoModelForCausalLM, pipeline, AutoModelForMaskedLM
from PIL import Image
import numpy as np
import pandas as pd
from nltk.translate import bleu_score
from nltk.translate.bleu_score import SmoothingFunction
import torch

yolo_weights_path = "final_wts.pt"
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten')
trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten').to(device)
trocr_model.config.num_beams = 1  # single beam, i.e. greedy decoding
# Fix: was hard-coded to 'mps', which fails on CUDA-only or CPU-only machines.
yolo_model = YOLO(yolo_weights_path).to(device)
unmasker_large = pipeline('fill-mask', model='roberta-large', device=device)
roberta_model = AutoModelForMaskedLM.from_pretrained("roberta-large").to(device)
print(f'TrOCR and YOLO Models loaded on {device}')

# -------------------------------------------------------

# Lines whose confidence score is not strictly above this are discarded.
CONFIDENCE_THRESHOLD = 0.72
# A line whose BLEU against the next line reaches this makes that next line a duplicate candidate.
BLEU_THRESHOLD = 0.6


def inference(image_path, debug=False, return_texts='final'):
    """Detect text lines in an image with YOLO and transcribe them with TrOCR.

    Pipeline: crop YOLO detections (sorted top to bottom) -> TrOCR decode ->
    strip a leading '# ' / trailing ' #' -> keep lines above CONFIDENCE_THRESHOLD ->
    merge adjacent near-duplicate lines (BLEU >= BLEU_THRESHOLD), keeping the
    higher-scoring one.

    Args:
        image_path: path of the page image.
        debug: if True, print the number of detected crops.
        return_texts: which stage to return:
            'generated'           -> DataFrame(text, score, y) of raw TrOCR output
            'post_processed'      -> same, after the '#'-stripping step
            'qualified'           -> DataFrame(text, score, y, logits, tokens) of confident lines
            'qualified_with_bleu' -> same plus a 'bleu' column (vs. the following line)
            'final'               -> DataFrame(text, score, y) after de-duplication
            any other value       -> tuple (final DataFrame, annotated-image path,
                                     per-line tokens, per-line logits,
                                     post-processed texts of all crops)

    NOTE(review): if YOLO detects nothing, the TrOCR processor is called with an
    empty image list — confirm how that should be handled.
    """

    def get_cropped_images(image_path):
        # save=True: ultralytics writes an annotated copy under result.save_dir;
        # the path returned below assumes it is saved as '<stem>.jpg'.
        results = yolo_model(image_path, save=True)
        result = results[0]
        # Fix: open/convert the page once (and close the file) instead of once per box.
        with Image.open(image_path) as page:
            image = page.convert("RGB")
        patches = []
        ys = []
        # Sort detections by box centre y so lines come out top to bottom.
        for box in sorted(result.boxes, key=lambda b: b.xywh[0][1]):
            x_center, y_center, w, h = box.xywh[0].cpu().numpy()
            x, y = x_center - w / 2, y_center - h / 2
            patches.append(image.crop((x, y, x + w, y + h)))
            ys.append(y)
        # Fix: build the path with os.path instead of slicing on '/' and assuming a
        # 4-character extension (raised on paths without '/', mangled '.jpeg').
        stem = os.path.splitext(os.path.basename(result.path))[0]
        bounding_box_path = os.path.join(result.save_dir, stem + '.jpg')
        return patches, ys, bounding_box_path

    def get_model_output(images):
        pixel_values = processor(images=images, return_tensors="pt").pixel_values.to(device)
        output = trocr_model.generate(pixel_values, return_dict_in_generate=True, output_logits=True, max_new_tokens=30)
        generated_texts = processor.batch_decode(output.sequences, skip_special_tokens=True)
        generated_tokens = [processor.tokenizer.convert_ids_to_tokens(seq) for seq in output.sequences]
        # Per transformers' generate(), output.logits holds one (batch, vocab) tensor per
        # generated step; stacking on dim=1 gives (batch, steps, vocab).
        stacked_logits = torch.stack(output.logits, dim=1)
        return generated_texts, stacked_logits, generated_tokens

    def get_scores(logits):
        # Per-line confidence: mean over steps of the max softmax probability.
        # Fix: convert to CPU Python floats so the scores can be put into a DataFrame
        # when running on CUDA/MPS (pandas cannot convert device tensors).
        return logits.softmax(-1).max(-1).values.mean(-1).cpu().tolist()

    def post_process_texts(generated_texts):
        # Strip a leading '# ' and/or trailing ' #' (presumably TrOCR artifacts), but only
        # while something longer than the marker remains. Mutates and returns the list.
        for i, text in enumerate(generated_texts):
            if len(text) > 2 and text.startswith('# '):
                text = text[2:]
            if len(text) > 2 and text.endswith(' #'):
                text = text[:-2]
            generated_texts[i] = text
        return generated_texts

    def get_qualified_texts(generated_texts, scores, y, logits, tokens):
        # Keep only lines whose confidence score exceeds CONFIDENCE_THRESHOLD.
        return [
            {'text': text, 'score': score, 'y': y_i, 'logits': logits_i, 'tokens': tokens_i}
            for text, score, y_i, logits_i, tokens_i in zip(generated_texts, scores, y, logits, tokens)
            if score > CONFIDENCE_THRESHOLD
        ]

    def get_adjacent_bleu_scores(qualified_texts):
        smoothing = SmoothingFunction()

        def get_bleu_score(hypothesis, references):
            # Unigram + bigram BLEU with equal weights; method1 smoothing avoids a hard
            # zero when there is no bigram overlap.
            return bleu_score.sentence_bleu(references, hypothesis, weights=[0.5, 0.5],
                                            smoothing_function=smoothing.method1)

        # Each line is scored against the line directly below it; the last line gets 0.
        for i, entry in enumerate(qualified_texts):
            bleu = 0
            if i + 1 < len(qualified_texts):
                bleu = get_bleu_score(entry['text'].split(), [qualified_texts[i + 1]['text'].split()])
            entry['bleu'] = bleu
        return qualified_texts

    def remove_overlapping_texts(qualified_texts):
        # Walk lines top to bottom. If the previous line's BLEU (vs. this line) reached
        # BLEU_THRESHOLD, this line is a duplicate detection: keep whichever of the two
        # has the higher score instead of appending.
        final_texts = []
        new = True
        for entry in qualified_texts:
            if new:
                final_texts.append(entry)
            elif final_texts[-1]['score'] < entry['score']:
                final_texts[-1] = entry
            new = entry['bleu'] < BLEU_THRESHOLD
        return final_texts

    cropped_images, y, bounding_box_path = get_cropped_images(image_path)
    if debug:
        print('Number of cropped images:', len(cropped_images))
    generated_texts, logits, gen_tokens = get_model_output(cropped_images)
    normalised_scores = get_scores(logits)
    if return_texts == 'generated':
        return pd.DataFrame({
            'text': generated_texts,
            'score': normalised_scores,
            'y': y,
        })

    generated_texts = post_process_texts(generated_texts)
    if return_texts == 'post_processed':
        return pd.DataFrame({
            'text': generated_texts,
            'score': normalised_scores,
            'y': y,
        })

    qualified_texts = get_qualified_texts(generated_texts, normalised_scores, y, logits, gen_tokens)
    if return_texts == 'qualified':
        return pd.DataFrame(qualified_texts)

    qualified_texts = get_adjacent_bleu_scores(qualified_texts)
    if return_texts == 'qualified_with_bleu':
        return pd.DataFrame(qualified_texts)

    final_texts = remove_overlapping_texts(qualified_texts)
    final_texts_df = pd.DataFrame(final_texts, columns=['text', 'score', 'y'])
    final_tokens = [text['tokens'] for text in final_texts]
    final_logits = [text['logits'] for text in final_texts]
    if return_texts == 'final':
        return final_texts_df
    return final_texts_df, bounding_box_path, final_tokens, final_logits, generated_texts


image_path = "raw_dataset/g06-037h.png"
# 'final_v2' is not a named stage, so inference() returns the full tuple.
df, bounding_path, tokens, logits, gen_texts = inference(image_path, debug=False, return_texts='final_v2')

# ----------------------------------------------------------------

# RoBERTa's <mask> id, looked up instead of hard-coding 50264.
MASK_TOKEN_ID = unmasker_large.tokenizer.mask_token_id
# Token positions whose TrOCR max-probability is below this are candidates for re-prediction.
LOW_CONFIDENCE_THRESHOLD = 0.5
# NOTE(review): debugging filter — only this line index actually gets masked below.
MASKED_ROW = 6


def get_new_logits(tokens):
    """Run roberta_model over all lines of `tokens` as one sequence and return its logits.

    `tokens` is an (n_lines, steps) id tensor; it is flattened into a single
    (1, n_lines * steps) input. Prints RoBERTa's argmax decoding for inspection and
    returns the logits reshaped to (n_lines, steps, vocab).
    """
    # NOTE(review): roberta-large has a 512-position limit, so this presumably fails
    # once n_lines * steps exceeds it — consider scoring lines separately.
    inputs = tokens.reshape(1, -1)
    with torch.no_grad():
        outputs = roberta_model(input_ids=inputs, attention_mask=torch.ones(inputs.shape).to(device))
    logits = outputs.logits
    # Fix: previously used the global `slogits` for the vocab size, which only worked
    # because the script defined it before calling this function.
    logits_flattened = logits.reshape(-1, logits.shape[-1])
    print(processor.batch_decode([logits_flattened.argmax(-1)], skip_special_tokens=True))
    return logits.reshape(tokens.shape + (logits.shape[-1],))


slogits = torch.stack(logits, dim=0)  # (n_lines, steps, vocab)
tokens = slogits.argmax(-1)
confidence = slogits.softmax(-1).max(-1).values
indices = torch.where(confidence < LOW_CONFIDENCE_THRESHOLD)

# Replace low-confidence tokens with <mask> so RoBERTa can re-predict them.
for i, j in zip(indices[0], indices[1]):
    if i != MASKED_ROW:
        continue
    tokens[i, j] = MASK_TOKEN_ID
new_logits = get_new_logits(tokens)

# ----------------------------------------------------------------

# Blend TrOCR and RoBERTa logits at the low-confidence positions and re-decode.
# NOTE(review): blending happens at every low-confidence position, including rows
# that were not masked above (only MASKED_ROW is) — confirm this is intended.
for i, j in zip(indices[0], indices[1]):
    slogits[i, j] = slogits[i, j] * 0.1 + new_logits[i, j] * 0.5
logits_flattened = slogits.reshape(-1, slogits.shape[-1])
# Previously a bare notebook-cell expression whose result was discarded in a script.
print(processor.batch_decode([logits_flattened.argmax(-1)], skip_special_tokens=True))