madhavkotecha committed on
Commit b7dfd46 · verified · 1 Parent(s): fa9ebe1

Update app.py

Files changed (1): app.py (+58 -73)
app.py CHANGED
@@ -1,13 +1,13 @@
-from pycparser.ply.yacc import token
+import gradio as gr
 from ultralytics import YOLO
-from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoModelForCausalLM, pipeline, AutoModelForMaskedLM
+from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoModelForMaskedLM
 from PIL import Image
 import numpy as np
 import pandas as pd
+import tempfile
 from nltk.translate import bleu_score
 from nltk.translate.bleu_score import SmoothingFunction
 import torch
-import gradio as gr
 
 yolo_weights_path = "final_wts.pt"
@@ -15,15 +15,16 @@ device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is
 
 processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten')
 trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten').to(device)
-trocr_model.config.num_beams = 1
+trocr_model.config.num_beams = 2
 
-yolo_model = YOLO(yolo_weights_path).to('cpu')
-unmasker_large = pipeline('fill-mask', model='roberta-large', device=device)
+yolo_model = YOLO(yolo_weights_path).to(device)
 roberta_model = AutoModelForMaskedLM.from_pretrained("roberta-large").to(device)
 
-print(f'TrOCR and YOLO Models loaded on {device}')
+print(f'TrOCR, YOLO and Roberta Models loaded on {device}')
 
+CONFIDENCE_THRESHOLD = 0.72
+BLEU_THRESHOLD = 0.6
 
 
 CONFIDENCE_THRESHOLD = 0.72
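For context, the device fallback shown in the hunk header above picks the first available accelerator; a minimal, runnable sketch of that pattern (the print string matches the diff):

import torch

# Fall back from CUDA to Apple-Silicon MPS to CPU.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print(f'TrOCR, YOLO and Roberta Models loaded on {device}')

Setting trocr_model.config.num_beams = 2 makes subsequent generate() calls run a 2-beam search instead of the greedy decoding that num_beams = 1 selects.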
@@ -61,7 +62,7 @@ def inference(image_path, debug=False, return_texts='final'):
     for i in range(len(generated_texts)):
         if len(generated_texts[i]) > 2 and generated_texts[i][:2] == '# ':
             generated_texts[i] = generated_texts[i][2:]
-
+
         if len(generated_texts[i]) > 2 and generated_texts[i][-2:] == ' #':
             generated_texts[i] = generated_texts[i][:-2]
     return generated_texts
@@ -107,11 +108,29 @@ def inference(image_path, debug=False, return_texts='final'):
             new = qualified_texts[i]['bleu'] < BLEU_THRESHOLD
         return final_texts
 
+    def get_lm_logits(ocr_tokens, confidence):
+        tokens = ocr_tokens.clone()
+        indices = torch.where(confidence < 0.5)
+        for i, j in zip(indices[0], indices[1]):
+            if i != 6:
+                continue
+            tokens[i, j] = torch.tensor(50264)
+        inputs = tokens.reshape(1, -1)
+        with torch.no_grad():
+            outputs = roberta_model(input_ids=inputs, attention_mask=torch.ones(inputs.shape).to(device))
+            lm_logits = outputs.logits
+        return lm_logits.reshape(ocr_tokens.shape[0], ocr_tokens.shape[1], -1), indices
+
+
+
     cropped_images, y, bounding_box_path = get_cropped_images(image_path)
     if debug:
         print('Number of cropped images:', len(cropped_images))
     generated_texts, logits, gen_tokens = get_model_output(cropped_images)
     normalised_scores = get_scores(logits)
+    generated_df = pd.DataFrame({
+        'text': generated_texts,
+    })
     if return_texts == 'generated':
         return pd.DataFrame({
             'text': generated_texts,
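The new get_lm_logits helper masks out low-confidence OCR tokens (50264 is the id of RoBERTa's <mask> token) and lets the masked language model re-predict them from sentence context. A self-contained sketch of that fill-mask idea, using the tokenizer's mask_token instead of the hard-coded id (the example sentence is illustrative):

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("roberta-large")
model = AutoModelForMaskedLM.from_pretrained("roberta-large")

# Pretend OCR was unsure about one word and replace it with <mask> (id 50264).
text = f"The quick brown {tokenizer.mask_token} jumps over the lazy dog."
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

# Read off the model's top prediction at the masked position.
mask_pos = (inputs.input_ids == tokenizer.mask_token_id).nonzero()[0, 1]
print(tokenizer.decode(logits[0, mask_pos].argmax().item()))  # likely " fox"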
@@ -133,81 +152,47 @@ def inference(image_path, debug=False, return_texts='final'):
         return pd.DataFrame(qualified_texts)
     final_texts = remove_overlapping_texts(qualified_texts)
     final_texts_df = pd.DataFrame(final_texts, columns=['text', 'score', 'y'])
-    final_tokens = [text['tokens'] for text in final_texts]
     final_logits = [text['logits'] for text in final_texts]
+    logits = torch.stack([logit for logit in final_logits], dim=0)
+    tokens = logits.argmax(-1)
+    confidence = logits.softmax(-1).max(-1).values
     if return_texts == 'final':
         return final_texts_df
-
-    return final_texts_df, bounding_box_path, final_tokens, final_logits, generated_texts
-
-
-# image_path = "raw_dataset/g06-037h.png"
-# df, bounding_path, tokens, logits, gen_texts = inference(image_path, debug=False, return_texts='final_v2')
-
-
-def get_new_logits(tokens):
-    inputs = tokens.reshape(1, -1)
-    # Get the logits from the model
-    with torch.no_grad():
-        outputs = roberta_model(input_ids=inputs, attention_mask=torch.ones(inputs.shape).to(device))
-        logits = outputs.logits
-
-    logits_flattened = logits.reshape(-1, slogits.shape[-1])
-    print(processor.batch_decode([logits_flattened.argmax(-1)], skip_special_tokens=True))
-    return logits.reshape(tokens.shape + (logits.shape[-1],))
-
-
-slogits = torch.stack([logit for logit in logits], dim=0)
-tokens = slogits.argmax(-1)
-confidence = slogits.softmax(-1).max(-1).values
-indices = torch.where(confidence < 0.5)
-# put 50264(mask) when confidence < 0.5
-for i, j in zip(indices[0], indices[1]):
-    if i != 6:
-        continue
-    tokens[i, j] = torch.tensor(50264)
-
-new_logits = get_new_logits(tokens)
-
-for i, j in zip(indices[0], indices[1]):
-    slogits[i, j] = slogits[i, j] * 0.1 + new_logits[i, j] * 0.5
-
-logits_flattened = slogits.reshape(-1, slogits.shape[-1])
-processor.batch_decode([logits_flattened.argmax(-1)], skip_special_tokens=True)
 
+    lm_logits, indices = get_lm_logits(tokens, confidence)
+    combined_logits = logits.clone()
+    for i, j in zip(indices[0], indices[1]):
+        combined_logits[i, j] = logits[i, j] * 0.9 + lm_logits[i, j] * 0.1
 
-def gradio_inference(image_path):
-    """
-    Function to handle inference and output the generated texts and final processed texts.
-    """
-    df, bounding_path, tokens, logits, gen_texts = inference(image_path, debug=False, return_texts='final_v2')
-
-    # Convert the DataFrame for final texts to a readable format
-    final_texts = df.to_string(index=False)
-
-    # Convert the list of generated texts into a readable string
-    gen_texts_output = '\n'.join(gen_texts)
-
-    return gen_texts_output, final_texts
+    return final_texts_df, bounding_box_path, tokens, combined_logits, confidence, generated_df
 
-image_input = gr.inputs.Image(type="filepath", label="Upload Image")
-generated_output = gr.outputs.Textbox(label="Generated Texts")
-final_output = gr.outputs.Textbox(label="Final Processed Texts")
 
+def process_image(image):
+    text, bounding_path = "", ""
+    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_image:
+        image.save(temp_image.name)
+        image_path = temp_image.name
+    df, bounding_path, tokens, logits, confidence, generated_df = inference(image_path, debug=False, return_texts='final_v2')
+    text = df['text'].str.cat(sep='\n')
+    before_text = generated_df['text'].str.cat(sep='\n')
+    bounding_img = Image.open(bounding_path)
+    return bounding_img, before_text, text
 
+# Define Gradio Interface
 interface = gr.Interface(
-    fn=gradio_inference,
-    inputs=image_input,
-    outputs=[generated_output, final_output],
-    title="OCR using LLMs",
-    description="Upload an image and get generated and final processed texts",
+    fn=process_image,  # Call the process_image function
+    inputs=gr.Image(type="pil"),  # Expect an image input
+    outputs=[
+        gr.Image(type="pil", label="Bounding Box Image"),
+        gr.Textbox(label="Extracted Text"),
+        gr.Textbox(label="Post Processed Text"),
+    ],
+    title="OCR Pipeline with YOLO, TrOCR and Roberta",
+    description="Upload an image to detect text regions with YOLO, merge bounding boxes, and extract text using TrOCR, which is then post-processed with Roberta for contextual understanding.",
 )
 
-interface.launch()
+# Launch the interface
+if __name__ == "__main__":
+    interface.launch()
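The rewritten fusion step keeps 90% of the OCR logit mass and blends in 10% from the language model, and only at positions where OCR confidence fell below the threshold (the old module-level code used 0.1/0.5 weights instead). A small sketch of that blend with toy shapes (the function name and tensor sizes are illustrative):

import torch

def blend_logits(ocr_logits, lm_logits, confidence, threshold=0.5):
    """Mix LM logits into OCR logits, but only where OCR confidence is low."""
    combined = ocr_logits.clone()
    low = confidence < threshold                 # (lines, tokens) boolean mask
    combined[low] = ocr_logits[low] * 0.9 + lm_logits[low] * 0.1
    return combined

# Toy tensors: 2 text lines, 4 tokens each, vocabulary of 8.
ocr = torch.randn(2, 4, 8)
lm = torch.randn(2, 4, 8)
confidence = ocr.softmax(-1).max(-1).values
print(blend_logits(ocr, lm, confidence).shape)   # torch.Size([2, 4, 8])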
 
 
 
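Two details of the new Gradio wiring are worth noting. First, the gr.inputs / gr.outputs namespaces used by the old code were deprecated and later removed from Gradio, so components are now constructed directly (gr.Image, gr.Textbox). Second, process_image creates its temporary file with delete=False, which keeps the file on disk after the with-block exits but never removes it. A hedged sketch of the same round-trip with explicit cleanup (run_on_temp_png is a hypothetical helper, not part of the app):

import os
import tempfile
from PIL import Image

def run_on_temp_png(image, fn):
    # delete=False keeps the file after the with-block exits (an open
    # NamedTemporaryFile cannot be reopened by name on Windows), so the
    # caller must remove it once the work is done.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        image.save(tmp.name)
        path = tmp.name
    try:
        return fn(path)          # e.g. run inference(path) here
    finally:
        os.remove(path)          # clean up what delete=False left behind

print(run_on_temp_png(Image.new("RGB", (8, 8)), os.path.exists))  # True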