madhavkotecha committed: Update app.py

app.py CHANGED
@@ -1,13 +1,13 @@
-
+import gradio as gr
 from ultralytics import YOLO
-from transformers import TrOCRProcessor, VisionEncoderDecoderModel, pipeline, AutoModelForMaskedLM
+from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoModelForMaskedLM
 from PIL import Image
 import numpy as np
 import pandas as pd
+import tempfile
 from nltk.translate import bleu_score
 from nltk.translate.bleu_score import SmoothingFunction
 import torch
-import gradio as gr
 
 yolo_weights_path = "final_wts.pt"
 
@@ -15,15 +15,16 @@ device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is
 
 processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten')
 trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten').to(device)
-trocr_model.config.num_beams =
+trocr_model.config.num_beams = 2
 
-yolo_model = YOLO(yolo_weights_path).to(
+yolo_model = YOLO(yolo_weights_path).to(device)
-unmasker_large = pipeline('fill-mask', model='roberta-large', device=device)
 roberta_model = AutoModelForMaskedLM.from_pretrained("roberta-large").to(device)
 
-print(f'TrOCR and YOLO Models loaded on {device}')
+print(f'TrOCR, YOLO and Roberta Models loaded on {device}')
 
+CONFIDENCE_THRESHOLD = 0.72
+BLEU_THRESHOLD = 0.6
 
 
 CONFIDENCE_THRESHOLD = 0.72
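For context: num_beams = 2 switches TrOCR's decoder from greedy decoding to a 2-way beam search, and the YOLO model is now explicitly moved to the selected device. A minimal sketch of the same load-and-generate pattern on a single cropped line image (the file name line.png is a hypothetical stand-in, not part of the commit):

import torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

device = 'cuda' if torch.cuda.is_available() else 'cpu'
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten')
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten').to(device)
model.config.num_beams = 2  # beam width set by this commit

image = Image.open('line.png').convert('RGB')  # hypothetical cropped line image
pixel_values = processor(image, return_tensors='pt').pixel_values.to(device)
ids = model.generate(pixel_values)
print(processor.batch_decode(ids, skip_special_tokens=True)[0])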
@@ -61,7 +62,7 @@ def inference(image_path, debug=False, return_texts='final'):
     for i in range(len(generated_texts)):
         if len(generated_texts[i]) > 2 and generated_texts[i][:2] == '# ':
             generated_texts[i] = generated_texts[i][2:]
-
+
         if len(generated_texts[i]) > 2 and generated_texts[i][-2:] == ' #':
             generated_texts[i] = generated_texts[i][:-2]
     return generated_texts
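The hunk above only re-indents the blank line between the two cleanup branches; the surrounding code strips the '# ' prefix and ' #' suffix that TrOCR emits around some lines. The same cleanup can be written with str.removeprefix/str.removesuffix (Python 3.9+); a sketch, not what the commit ships:

def strip_hash_markers(texts):
    # Drop the '# ... #' markers around OCR lines, mirroring the loop above.
    return [t.removeprefix('# ').removesuffix(' #') for t in texts]

print(strip_hash_markers(['# hello world #', 'plain line']))  # ['hello world', 'plain line']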
@@ -107,11 +108,29 @@ def inference(image_path, debug=False, return_texts='final'):
         new = qualified_texts[i]['bleu'] < BLEU_THRESHOLD
         return final_texts
 
+    def get_lm_logits(ocr_tokens, confidence):
+        tokens = ocr_tokens.clone()
+        indices = torch.where(confidence < 0.5)
+        for i, j in zip(indices[0], indices[1]):
+            if i != 6:
+                continue
+            tokens[i, j] = torch.tensor(50264)
+        inputs = tokens.reshape(1, -1)
+        with torch.no_grad():
+            outputs = roberta_model(input_ids=inputs, attention_mask=torch.ones(inputs.shape).to(device))
+        lm_logits = outputs.logits
+        return lm_logits.reshape(ocr_tokens.shape[0], ocr_tokens.shape[1], -1), indices
+
+
+
     cropped_images, y, bounding_box_path = get_cropped_images(image_path)
     if debug:
         print('Number of cropped images:', len(cropped_images))
     generated_texts, logits, gen_tokens = get_model_output(cropped_images)
     normalised_scores = get_scores(logits)
+    generated_df = pd.DataFrame({
+        'text': generated_texts,
+    })
     if return_texts == 'generated':
         return pd.DataFrame({
             'text': generated_texts,
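The new get_lm_logits helper folds the old script-level experiment into inference(): token positions whose OCR confidence falls below 0.5 are overwritten with RoBERTa's mask token id (50264 for roberta-large) and the whole sequence is rescored by the masked language model. Note that the committed helper still carries the if i != 6: continue guard from the old experiment, so only the line at index 6 is masked. A self-contained sketch of the fill-mask rescoring idea (the sentence and the random confidences are stand-ins):

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('roberta-large')
model = AutoModelForMaskedLM.from_pretrained('roberta-large')

tokens = tokenizer('The quick brown fox jumps', return_tensors='pt').input_ids
confidence = torch.rand(tokens.shape)               # stand-in for OCR confidences
masked = tokens.clone()
masked[confidence < 0.5] = tokenizer.mask_token_id  # 50264 for roberta-large
with torch.no_grad():
    lm_logits = model(input_ids=masked).logits      # (1, seq_len, vocab_size)
print(tokenizer.decode(lm_logits.argmax(-1)[0]))    # LM's best guess per position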
@@ -133,81 +152,47 @@ def inference(image_path, debug=False, return_texts='final'):
         return pd.DataFrame(qualified_texts)
     final_texts = remove_overlapping_texts(qualified_texts)
     final_texts_df = pd.DataFrame(final_texts, columns=['text', 'score', 'y'])
-    final_tokens = [text['tokens'] for text in final_texts]
     final_logits = [text['logits'] for text in final_texts]
+    logits = torch.stack([logit for logit in final_logits], dim=0)
+    tokens = logits.argmax(-1)
+    confidence = logits.softmax(-1).max(-1).values
     if return_texts == 'final':
         return final_texts_df
-
-    return final_texts_df, bounding_box_path, final_tokens, final_logits, generated_texts
-
-
-# image_path = "raw_dataset/g06-037h.png"
-# df, bounding_path, tokens, logits, gen_texts = inference(image_path, debug=False, return_texts='final_v2')
-
-
-
-def get_new_logits(tokens):
-    inputs = tokens.reshape(1, -1)
-    # Get the logits from the model
-    with torch.no_grad():
-        outputs = roberta_model(input_ids=inputs, attention_mask=torch.ones(inputs.shape).to(device))
-    logits = outputs.logits
-
-
-    logits_flattened = logits.reshape(-1, slogits.shape[-1])
-    print(processor.batch_decode([logits_flattened.argmax(-1)], skip_special_tokens=True))
-    return logits.reshape(tokens.shape + (logits.shape[-1],))
-
-
-slogits = torch.stack([logit for logit in logits], dim=0)
-tokens = slogits.argmax(-1)
-confidence = slogits.softmax(-1).max(-1).values
-indices = torch.where(confidence < 0.5)
-# put 50264(mask) when confidence < 0.5
-for i, j in zip(indices[0], indices[1]):
-    if i != 6:
-        continue
-    tokens[i, j] = torch.tensor(50264)
-
-new_logits = get_new_logits(tokens)
-
-
-
-
-for i, j in zip(indices[0], indices[1]):
-    slogits[i, j] = slogits[i, j] * 0.1 + new_logits[i, j] * 0.5
-
-logits_flattened = slogits.reshape(-1, slogits.shape[-1])
-processor.batch_decode([logits_flattened.argmax(-1)], skip_special_tokens=True)
+
+    lm_logits, indices = get_lm_logits(tokens, confidence)
+    combined_logits = logits.clone()
+    for i, j in zip(indices[0], indices[1]):
+        combined_logits[i, j] = logits[i, j] * 0.9 + lm_logits[i, j] * 0.1
+
+    return final_texts_df, bounding_box_path, tokens, combined_logits, confidence, generated_df
 
-def gradio_inference(image_path):
-    """
-    Function to handle inference and output the generated texts and final processed texts.
-    """
-    df, bounding_path, tokens, logits, gen_texts = inference(image_path, debug=False, return_texts='final_v2')
-
-    # Convert the DataFrame for final texts to a readable format
-    final_texts = df.to_string(index=False)
-
-    # Convert the list of generated texts into a readable string
-    gen_texts_output = '\n'.join(gen_texts)
-
-    return gen_texts_output, final_texts
-
+
+def process_image(image):
+    text, bounding_path = "", ""
+    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_image:
+        image.save(temp_image.name)
+        image_path = temp_image.name
+    df, bounding_path, tokens, logits, confidence, generated_df = inference(image_path, debug=False, return_texts='final_v2')
+    text = df['text'].str.cat(sep='\n')
+    before_text = generated_df['text'].str.cat(sep='\n')
+    bounding_img = Image.open(bounding_path)
+    return bounding_img, before_text, text
 
+# Define Gradio Interface
 interface = gr.Interface(
-    fn=
-    inputs=
-    outputs=[
-
-
+    fn=process_image,  # Call the process_image function
+    inputs=gr.Image(type="pil"),  # Expect an image input
+    outputs=[
+        gr.Image(type="pil", label="Bounding Box Image"),
+        gr.Textbox(label="Extracted Text"),
+        gr.Textbox(label="Post Processed Text"),
+    ],
+    title="OCR Pipeline with YOLO, TrOCR and Roberta",
+    description="Upload an image to detect text regions with YOLO, merge bounding boxes, and extract text using TrOCR which is then preprocessed with Roberta for contextual understanding.",
 )
 
-interface
+# Launch the interface
+if __name__ == "__main__":
+    interface.launch()
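Net effect of this hunk: the module-level experiment (which decoded RoBERTa logits with the TrOCR processor and weighted OCR/LM logits 0.1/0.5) is replaced by an in-function interpolation that keeps the OCR signal dominant at weights 0.9/0.1, applied only at low-confidence positions. A sketch of the combination step with stand-in tensors (shapes are assumptions):

import torch

ocr_logits = torch.randn(3, 10, 50265)      # stand-in: (lines, seq_len, vocab)
lm_logits = torch.randn(3, 10, 50265)       # stand-in for RoBERTa fill-mask logits
confidence = ocr_logits.softmax(-1).max(-1).values

combined = ocr_logits.clone()
rows, cols = torch.where(confidence < 0.5)  # rescore only low-confidence positions
combined[rows, cols] = ocr_logits[rows, cols] * 0.9 + lm_logits[rows, cols] * 0.1
tokens = combined.argmax(-1)                # final token ids per line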
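Usage note: with fn=process_image and inputs=gr.Image(type="pil"), Gradio passes the handler a PIL image, which is why the commit writes it to a temporary PNG before calling the path-based inference(). The handler can be exercised outside the UI like this (the sample scan is a hypothetical stand-in):

from PIL import Image

img = Image.open('sample_page.png')               # hypothetical handwriting scan
bounding_img, before_text, after_text = process_image(img)
print(before_text)   # raw TrOCR output per detected region
print(after_text)    # text after filtering and RoBERTa rescoring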