# Hugging Face Spaces demo: docTR text detection + TrOCR text recognition.
import os
os.environ["USE_TORCH"] = "1"
os.environ["USE_TF"] = "0"
import torch
from torch.utils.data.dataloader import DataLoader
from builder import DocumentBuilder
from trocr import IAMDataset, device, get_processor_model
from doctr.utils.visualization import visualize_page
from doctr.models.predictor.base import _OCRPredictor
from doctr.models.detection.predictor import DetectionPredictor
from doctr.models.preprocessor import PreProcessor
from doctr.models import db_resnet50, db_mobilenet_v3_large
from doctr.io import DocumentFile
import numpy as np
import cv2
import matplotlib.pyplot as plt
import streamlit as st
DET_ARCHS = ["db_resnet50", "db_mobilenet_v3_large"]
RECO_ARCHS = ["microsoft/trocr-large-printed", "microsoft/trocr-large-stage1", "microsoft/trocr-large-handwritten"]
def main():
    """Streamlit app: detect text with docTR, recognize it with TrOCR.

    Side effects only — renders the page layout, lets the user upload a
    PDF/image, runs detection + recognition on the selected page, and
    displays the segmentation heatmap and the recognized text.
    """
    # Wide mode
    st.set_page_config(layout="wide")

    # Designing the interface
    st.title("docTR + TrOCR")
    st.write('\n')
    st.write('For Detection DocTR: https://github.com/mindee/doctr')
    st.write('\n')
    st.write('For Recognition TrOCR: https://github.com/microsoft/unilm/tree/master/trocr')
    st.write('\n')
    st.write('Any Issue please dm')
    st.write('\n')
    # Instructions
    st.markdown(
        "*Hint: click on the top-right corner of an image to enlarge it!*")
    # Set the columns
    cols = st.columns((1, 1, 1))
    cols[0].subheader("Input page")
    cols[1].subheader("Segmentation heatmap")

    # Sidebar — file selection
    st.sidebar.title("Document selection")
    # Disabling the deprecated file-uploader encoding warning
    st.set_option('deprecation.showfileUploaderEncoding', False)
    # Choose your own image
    uploaded_file = st.sidebar.file_uploader(
        "Upload files", type=['pdf', 'png', 'jpeg', 'jpg'])
    if uploaded_file is not None:
        if uploaded_file.name.endswith('.pdf'):
            doc = DocumentFile.from_pdf(uploaded_file.read()).as_images()
        else:
            doc = DocumentFile.from_images(uploaded_file.read())
        # Selectbox shows 1-based page numbers; convert back to 0-based index.
        page_idx = st.sidebar.selectbox(
            "Page selection", [idx + 1 for idx in range(len(doc))]) - 1
        cols[0].image(doc[page_idx])

    # Sidebar — model selection
    st.sidebar.title("Model selection")
    det_arch = st.sidebar.selectbox("Text detection model", DET_ARCHS)
    rec_arch = st.sidebar.selectbox("Text recognition model", RECO_ARCHS)
    st.sidebar.write('\n')

    if st.sidebar.button("Analyze page"):
        if uploaded_file is None:
            st.sidebar.write("Please upload a document")
        else:
            with st.spinner('Loading model...'):
                if det_arch == "db_resnet50":
                    det_model = db_resnet50(pretrained=True)
                else:
                    det_model = db_mobilenet_v3_large(pretrained=True)
                det_predictor = DetectionPredictor(
                    PreProcessor((1024, 1024), batch_size=1,
                                 mean=(0.798, 0.785, 0.772),
                                 std=(0.264, 0.2749, 0.287)),
                    det_model)
                rec_processor, rec_model = get_processor_model(rec_arch)

            with st.spinner('Analyzing...'):
                # Only the selected page is analyzed; keep it in one list so
                # every downstream step (crops, shapes) stays consistent.
                page = doc[page_idx]
                pages = [page]

                # Forward the page through the detection model
                processed_batches = det_predictor.pre_processor(pages)
                out = det_predictor.model(
                    processed_batches[0], return_model_output=True)
                seg_map = out["out_map"]
                # Drop the channel dimension (canonical kwarg is `dim`,
                # not the numpy-style `axis` alias).
                seg_map = torch.squeeze(seg_map[0, ...], dim=0)
                seg_map = cv2.resize(
                    seg_map.detach().numpy(),
                    (page.shape[1], page.shape[0]),
                    interpolation=cv2.INTER_LINEAR)
                # Plot the raw heatmap
                fig, ax = plt.subplots()
                ax.imshow(seg_map)
                ax.axis('off')
                cols[1].pyplot(fig)

                # Localize text elements
                loc_preds = out["preds"]
                # Check whether crop mode should be switched to channels first
                channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)
                # Crop images. BUG FIX: pass only the analyzed page —
                # the original passed the whole `doc`, so `loc_preds`
                # (one page of predictions) was always paired with page 0,
                # recognizing the wrong page of multi-page PDFs.
                crops, loc_preds = _OCRPredictor._prepare_crops(
                    pages, loc_preds,
                    channels_last=channels_last, assume_straight_pages=True
                )

                # Recognize each crop with TrOCR, batched for throughput.
                test_dataset = IAMDataset(crops[0], rec_processor)
                test_dataloader = DataLoader(test_dataset, batch_size=16)
                text = []
                with torch.no_grad():
                    for batch in test_dataloader:
                        pixel_values = batch["pixel_values"].to(device)
                        generated_ids = rec_model.generate(pixel_values)
                        generated_text = rec_processor.batch_decode(
                            generated_ids, skip_special_tokens=True)
                        text.extend(generated_text)

                # Re-associate recognized strings with their boxes and
                # build a structured document for display.
                boxes, text_preds = _OCRPredictor._process_predictions(
                    loc_preds, text)
                doc_builder = DocumentBuilder()
                out = doc_builder(
                    boxes,
                    text_preds,
                    [
                        # type: ignore[misc]
                        p.shape[:2] if channels_last else p.shape[-2:]
                        for p in pages
                    ]
                )
                for df in out:
                    st.markdown("text")
                    st.write(" ".join(df["word"].to_list()))
                    st.write('\n')
                    st.markdown("\n Dataframe Output- similar to Tesseract:")
                    st.dataframe(df)
# Script entry point. BUG FIX: removed the stray trailing " |" scrape
# artifact after main(), which made this line a SyntaxError.
if __name__ == '__main__':
    main()