madhavkotecha committed
Commit acec273
1 Parent(s): 97c40f1

Create app.py

Files changed (1)
  1. app.py +185 -0
app.py ADDED
@@ -0,0 +1,185 @@
+ from ultralytics import YOLO
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoModelForCausalLM, pipeline, AutoModelForMaskedLM
+ from PIL import Image
+ import numpy as np
+ import pandas as pd
+ from nltk.translate import bleu_score
+ from nltk.translate.bleu_score import SmoothingFunction
+ import torch
+
+ yolo_weights_path = "final_wts.pt"
+
+ device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
+
+ processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten')
+ trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten').to(device)
+ trocr_model.config.num_beams = 1
+
+ yolo_model = YOLO(yolo_weights_path).to(device)
+ unmasker_large = pipeline('fill-mask', model='roberta-large', device=device)
+ roberta_model = AutoModelForMaskedLM.from_pretrained("roberta-large").to(device)
+
+ print(f'TrOCR, YOLO and RoBERTa models loaded on {device}')
+
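+ # Overall flow: YOLO detects handwritten text lines, TrOCR transcribes each cropped line,
+ # near-duplicate detections are merged via BLEU, and low-confidence tokens are later masked
+ # and re-scored with RoBERTa.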
+
+ # -------------------------------------------------------
+
+
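+ # Minimum mean token confidence for a transcribed line to be kept, and the BLEU overlap above
+ # which two adjacent lines are treated as duplicate detections of the same line.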
+ CONFIDENCE_THRESHOLD = 0.72
+ BLEU_THRESHOLD = 0.6
+
+
+ def inference(image_path, debug=False, return_texts='final'):
+     def get_cropped_images(image_path):
+         results = yolo_model(image_path, save=True)
+         patches = []
+         ys = []
+         for box in sorted(results[0].boxes, key=lambda x: x.xywh[0][1]):
+             image = Image.open(image_path).convert("RGB")
+             x_center, y_center, w, h = box.xywh[0].cpu().numpy()
+             x, y = x_center - w / 2, y_center - h / 2
+             cropped_image = image.crop((x, y, x + w, y + h))
+             patches.append(cropped_image)
+             ys.append(y)
+         bounding_box_path = results[0].save_dir + results[0].path[results[0].path.rindex('/'):-4] + '.jpg'
+         return patches, ys, bounding_box_path
+
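+     # TrOCR transcribes all cropped lines as one batch; the per-step generation logits are
+     # kept so token-level confidences can be computed afterwards.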
+     def get_model_output(images):
+         pixel_values = processor(images=images, return_tensors="pt").pixel_values.to(device)
+         output = trocr_model.generate(pixel_values, return_dict_in_generate=True, output_logits=True, max_new_tokens=30)
+         generated_texts = processor.batch_decode(output.sequences, skip_special_tokens=True)
+         generated_tokens = [processor.tokenizer.convert_ids_to_tokens(seq) for seq in output.sequences]
+         stacked_logits = torch.stack(output.logits, dim=1)
+         return generated_texts, stacked_logits, generated_tokens
+
+     def get_scores(logits):
+         scores = logits.softmax(-1).max(-1).values.mean(-1)
+         return scores
+
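+     # Strip a stray leading '# ' or trailing ' #' from the decoded text.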
+     def post_process_texts(generated_texts):
+         for i in range(len(generated_texts)):
+             if len(generated_texts[i]) > 2 and generated_texts[i][:2] == '# ':
+                 generated_texts[i] = generated_texts[i][2:]
+
+             if len(generated_texts[i]) > 2 and generated_texts[i][-2:] == ' #':
+                 generated_texts[i] = generated_texts[i][:-2]
+         return generated_texts
+
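+     # Keep only lines whose mean token confidence clears CONFIDENCE_THRESHOLD, carrying each
+     # line's logits and tokens along for the later RoBERTa refinement step.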
+     def get_qualified_texts(generated_texts, scores, y, logits, tokens):
+         qualified_texts = []
+         for text, score, y_i, logits_i, tokens_i in zip(generated_texts, scores, y, logits, tokens):
+             if score > CONFIDENCE_THRESHOLD:
+                 qualified_texts.append({
+                     'text': text,
+                     'score': score,
+                     'y': y_i,
+                     'logits': logits_i,
+                     'tokens': tokens_i
+                 })
+         return qualified_texts
+
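+     # Compare each line with the next one using a bigram-smoothed sentence BLEU; a high score
+     # suggests the two detections cover the same handwritten line.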
+     def get_adjacent_bleu_scores(qualified_texts):
+         def get_bleu_score(hypothesis, references):
+             weights = [0.5, 0.5]
+             smoothing = SmoothingFunction()
+             return bleu_score.sentence_bleu(references, hypothesis, weights=weights,
+                                             smoothing_function=smoothing.method1)
+
+         for i in range(len(qualified_texts)):
+             hyp = qualified_texts[i]['text'].split()
+             bleu = 0
+             if i < len(qualified_texts) - 1:
+                 ref = qualified_texts[i + 1]['text'].split()
+                 bleu = get_bleu_score(hyp, [ref])
+             qualified_texts[i]['bleu'] = bleu
+         return qualified_texts
+
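+     # Collapse runs of near-duplicate detections (adjacent BLEU >= BLEU_THRESHOLD), keeping
+     # the highest-scoring line from each run.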
+     def remove_overlapping_texts(qualified_texts):
+         final_texts = []
+         new = True
+         for i in range(len(qualified_texts)):
+             if new:
+                 final_texts.append(qualified_texts[i])
+             else:
+                 if final_texts[-1]['score'] < qualified_texts[i]['score']:
+                     final_texts[-1] = qualified_texts[i]
+             new = qualified_texts[i]['bleu'] < BLEU_THRESHOLD
+         return final_texts
+
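+     # Drive the pipeline; return_texts selects which intermediate stage to return early.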
+     cropped_images, y, bounding_box_path = get_cropped_images(image_path)
+     if debug:
+         print('Number of cropped images:', len(cropped_images))
+     generated_texts, logits, gen_tokens = get_model_output(cropped_images)
+     normalised_scores = get_scores(logits)
+     if return_texts == 'generated':
+         return pd.DataFrame({
+             'text': generated_texts,
+             'score': normalised_scores,
+             'y': y,
+         })
+     generated_texts = post_process_texts(generated_texts)
+     if return_texts == 'post_processed':
+         return pd.DataFrame({
+             'text': generated_texts,
+             'score': normalised_scores,
+             'y': y
+         })
+     qualified_texts = get_qualified_texts(generated_texts, normalised_scores, y, logits, gen_tokens)
+     if return_texts == 'qualified':
+         return pd.DataFrame(qualified_texts)
+     qualified_texts = get_adjacent_bleu_scores(qualified_texts)
+     if return_texts == 'qualified_with_bleu':
+         return pd.DataFrame(qualified_texts)
+     final_texts = remove_overlapping_texts(qualified_texts)
+     final_texts_df = pd.DataFrame(final_texts, columns=['text', 'score', 'y'])
+     final_tokens = [text['tokens'] for text in final_texts]
+     final_logits = [text['logits'] for text in final_texts]
+     if return_texts == 'final':
+         return final_texts_df
+
+     return final_texts_df, bounding_box_path, final_tokens, final_logits, generated_texts
+
+
+ image_path = "raw_dataset/g06-037h.png"
+ df, bounding_path, tokens, logits, gen_texts = inference(image_path, debug=False, return_texts='final_v2')
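+ # Any unrecognised return_texts value (such as 'final_v2') falls through to the full
+ # (DataFrame, bounding-box path, tokens, logits, texts) tuple unpacked above.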
+
+ # ----------------------------------------------------------------
+
+
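+ # Re-score the (partially masked) TrOCR token ids with RoBERTa and return its logits,
+ # reshaped back to one row per line.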
+ def get_new_logits(tokens):
+     inputs = tokens.reshape(1, -1)
+     # Get the logits from the model
+     with torch.no_grad():
+         outputs = roberta_model(input_ids=inputs, attention_mask=torch.ones(inputs.shape).to(device))
+         logits = outputs.logits
+
+     logits_flattened = logits.reshape(-1, logits.shape[-1])
+     print(processor.batch_decode([logits_flattened.argmax(-1)], skip_special_tokens=True))
+     return logits.reshape(tokens.shape + (logits.shape[-1],))
+
+
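+ # Stack the per-line TrOCR logits, locate tokens whose max softmax probability is below 0.5,
+ # and replace those token ids with RoBERTa's <mask> id (50264) before re-scoring.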
+ slogits = torch.stack([logit for logit in logits], dim=0)
+ tokens = slogits.argmax(-1)
+ confidence = slogits.softmax(-1).max(-1).values
+ indices = torch.where(confidence < 0.5)
+ # put 50264 (<mask>) where confidence < 0.5
+ for i, j in zip(indices[0], indices[1]):
+     if i != 6:  # only the line at index 6 is masked here
+         continue
+     tokens[i, j] = torch.tensor(50264)
+
+ new_logits = get_new_logits(tokens)
+
+
+ # ----------------------------------------------------------------
+
+
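+ # Blend the original TrOCR logits with RoBERTa's logits at the low-confidence positions,
+ # then decode the corrected text.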
+ for i, j in zip(indices[0], indices[1]):
+     slogits[i, j] = slogits[i, j] * 0.1 + new_logits[i, j] * 0.5
+
+ logits_flattened = slogits.reshape(-1, slogits.shape[-1])
+ print(processor.batch_decode([logits_flattened.argmax(-1)], skip_special_tokens=True))