"""Streamlit demo: docTR text detection followed by TrOCR text recognition."""

import os

# Set before the docTR imports below; these variables are presumably read by
# docTR at import time to pick its deep-learning backend (TODO confirm for the
# installed docTR version).
os.environ["USE_TORCH"] = "1"
os.environ["USE_TF"] = "0"

import torch
from torch.utils.data.dataloader import DataLoader

from builder import DocumentBuilder
from trocr import IAMDataset, device, get_processor_model

from doctr.utils.visualization import visualize_page
from doctr.models.predictor.base import _OCRPredictor
from doctr.models.detection.predictor import DetectionPredictor
from doctr.models.preprocessor import PreProcessor
from doctr.models import db_resnet50, db_mobilenet_v3_large
from doctr.io import DocumentFile

import numpy as np
import cv2
import matplotlib.pyplot as plt
import streamlit as st

DET_ARCHS = ["db_resnet50", "db_mobilenet_v3_large"]
RECO_ARCHS = [
    "microsoft/trocr-large-printed",
    "microsoft/trocr-large-stage1",
    "microsoft/trocr-large-handwritten",
]


def _load_document(uploaded_file):
    """Decode an uploaded PDF or image file into an indexable sequence of pages."""
    payload = uploaded_file.read()
    # Compare case-insensitively so that e.g. "SCAN.PDF" is not sent to the
    # image decoder.
    if uploaded_file.name.lower().endswith('.pdf'):
        return DocumentFile.from_pdf(payload).as_images()
    return DocumentFile.from_images(payload)


def _build_detection_predictor(det_arch):
    """Instantiate the pretrained detection model named by ``det_arch``.

    Any value other than ``"db_resnet50"`` selects ``db_mobilenet_v3_large``.
    """
    if det_arch == "db_resnet50":
        det_model = db_resnet50(pretrained=True)
    else:
        det_model = db_mobilenet_v3_large(pretrained=True)
    pre_processor = PreProcessor(
        (1024, 1024),
        batch_size=1,
        mean=(0.798, 0.785, 0.772),
        std=(0.264, 0.2749, 0.287),
    )
    return DetectionPredictor(pre_processor, det_model)


def _recognize_crops(crops, rec_processor, rec_model):
    """Run the recognition model over ``crops`` and return the decoded strings."""
    dataset = IAMDataset(crops, rec_processor)
    loader = DataLoader(dataset, batch_size=16)
    text = []
    with torch.no_grad():
        for batch in loader:
            pixel_values = batch["pixel_values"].to(device)
            generated_ids = rec_model.generate(pixel_values)
            text.extend(
                rec_processor.batch_decode(generated_ids, skip_special_tokens=True)
            )
    return text


def main():
    """Render the page and, on request, run detection + recognition on one page.

    The selected page is shown in the first column, the detection heatmap in
    the second, and the recognized words are written below as plain text and
    as a dataframe.
    """
    # Wide mode
    st.set_page_config(layout="wide")

    # Designing the interface
    st.title("docTR + TrOCR")
    # For newline
    st.write('\n')
    st.write('For Detection DocTR: https://github.com/mindee/doctr')
    # For newline
    st.write('\n')
    st.write('For Recognition TrOCR: https://github.com/microsoft/unilm/tree/master/trocr')
    # For newline
    st.write('\n')
    st.write('Any Issue please dm')
    # For newline
    st.write('\n')
    # Instructions
    st.markdown(
        "*Hint: click on the top-right corner of an image to enlarge it!*")
    # Set the columns
    cols = st.columns((1, 1, 1))
    cols[0].subheader("Input page")
    cols[1].subheader("Segmentation heatmap")

    # Sidebar
    # File selection
    st.sidebar.title("Document selection")
    # Disabling warning
    st.set_option('deprecation.showfileUploaderEncoding', False)
    # Choose your own image
    uploaded_file = st.sidebar.file_uploader(
        "Upload files", type=['pdf', 'png', 'jpeg', 'jpg'])
    if uploaded_file is not None:
        doc = _load_document(uploaded_file)
        # The selectbox shows 1-based page numbers; convert back to an index.
        page_idx = st.sidebar.selectbox(
            "Page selection", [idx + 1 for idx in range(len(doc))]) - 1
        cols[0].image(doc[page_idx])

    # Model selection
    st.sidebar.title("Model selection")
    det_arch = st.sidebar.selectbox("Text detection model", DET_ARCHS)
    rec_arch = st.sidebar.selectbox("Text recognition model", RECO_ARCHS)

    # For newline
    st.sidebar.write('\n')

    if st.sidebar.button("Analyze page"):

        if uploaded_file is None:
            st.sidebar.write("Please upload a document")

        else:
            with st.spinner('Loading model...'):
                det_predictor = _build_detection_predictor(det_arch)
                rec_processor, rec_model = get_processor_model(rec_arch)

            with st.spinner('Analyzing...'):
                page = doc[page_idx]
                # Only the selected page is analyzed; every per-page list
                # below must therefore be built from this one-element list.
                pages = [page]

                # Forward the image to the model (inference only, no autograd)
                processed_batches = det_predictor.pre_processor(pages)
                with torch.no_grad():
                    det_out = det_predictor.model(
                        processed_batches[0], return_model_output=True)
                seg_map = det_out["out_map"]
                seg_map = torch.squeeze(seg_map[0, ...], axis=0)
                # cv2.resize takes the target size as (width, height).
                seg_map = cv2.resize(
                    seg_map.detach().cpu().numpy(),
                    (page.shape[1], page.shape[0]),
                    interpolation=cv2.INTER_LINEAR,
                )

                # Plot the raw heatmap
                fig, ax = plt.subplots()
                ax.imshow(seg_map)
                ax.axis('off')
                cols[1].pyplot(fig)
                # The script reruns on every interaction; release the figure
                # so they do not pile up in memory.
                plt.close(fig)

                # Plot OCR output
                # Localize text elements
                loc_preds = det_out["preds"]

                # Check whether crop mode should be switched to channels first
                channels_last = isinstance(page, np.ndarray)

                # Crop images from the page the boxes were predicted on.
                # Passing the whole document here would pair these boxes with
                # its first page instead of the selected one.
                crops, loc_preds = _OCRPredictor._prepare_crops(
                    pages, loc_preds, channels_last=channels_last,
                    assume_straight_pages=True
                )

                text = _recognize_crops(crops[0], rec_processor, rec_model)
                boxes, text_preds = _OCRPredictor._process_predictions(
                    loc_preds, text)

                doc_builder = DocumentBuilder()
                results = doc_builder(
                    boxes,
                    text_preds,
                    [
                        # type: ignore[misc]
                        p.shape[:2] if channels_last else p.shape[-2:]
                        for p in pages
                    ]
                )

                for df in results:
                    st.markdown("text")
                    st.write(" ".join(df["word"].to_list()))
                    st.write('\n')
                    st.markdown("\n Dataframe Output- similar to Tesseract:")
                    st.dataframe(df)


if __name__ == '__main__':
    main()