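"""Streamlit demo: text detection with docTR and text recognition with TrOCR."""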
import os
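# Select docTR's PyTorch backend (rather than TensorFlow) before doctr is imported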
os.environ["USE_TORCH"] = "1"
os.environ["USE_TF"] = "0"
import torch
from torch.utils.data import DataLoader

from builder import DocumentBuilder
from trocr import IAMDataset, device, get_processor_model
from doctr.models.predictor.base import _OCRPredictor
from doctr.models.detection.predictor import DetectionPredictor
from doctr.models.preprocessor import PreProcessor
from doctr.models import db_resnet50, db_mobilenet_v3_large

from doctr.io import DocumentFile
import numpy as np
import cv2
import matplotlib.pyplot as plt
import streamlit as st

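# Detection backbones come from docTR; recognition checkpoints are TrOCR models on the Hugging Face Hub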
DET_ARCHS = ["db_resnet50", "db_mobilenet_v3_large"]
RECO_ARCHS = ["microsoft/trocr-large-printed", "microsoft/trocr-large-stage1", "microsoft/trocr-large-handwritten"]


def main():
    # Wide mode
    st.set_page_config(layout="wide")
    # Designing the interface
    st.title("docTR + TrOCR")
    # For newline
    st.write('\n')
    st.write('Detection powered by docTR: https://github.com/mindee/doctr')
    st.write('\n')
    st.write('Recognition powered by TrOCR: https://github.com/microsoft/unilm/tree/master/trocr')
    st.write('\n')
    st.write('For any issues, please DM')
    st.write('\n')
    # Instructions
    st.markdown(
        "*Hint: click on the top-right corner of an image to enlarge it!*")
    # Set the columns
    cols = st.columns((1, 1))
    cols[0].subheader("Input page")
    cols[1].subheader("Segmentation heatmap")
    
    # Sidebar
    # File selection
    st.sidebar.title("Document selection")
    # Disable the deprecated file uploader encoding warning
    st.set_option('deprecation.showfileUploaderEncoding', False)
    # Choose your own image
    uploaded_file = st.sidebar.file_uploader(
        "Upload files", type=['pdf', 'png', 'jpeg', 'jpg'])
    if uploaded_file is not None:
        if uploaded_file.name.endswith('.pdf'):
            doc = DocumentFile.from_pdf(uploaded_file.read()).as_images()
        else:
            doc = DocumentFile.from_images(uploaded_file.read())
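        # Pages are listed 1-based in the sidebar; convert back to a 0-based index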
        page_idx = st.sidebar.selectbox(
            "Page selection", [idx + 1 for idx in range(len(doc))]) - 1
        cols[0].image(doc[page_idx])
    # Model selection
    st.sidebar.title("Model selection")
    det_arch = st.sidebar.selectbox("Text detection model", DET_ARCHS)
    rec_arch = st.sidebar.selectbox("Text recognition model", RECO_ARCHS)
    # For newline
    st.sidebar.write('\n')
    if st.sidebar.button("Analyze page"):
        if uploaded_file is None:
            st.sidebar.write("Please upload a document")
        else:
            with st.spinner('Loading model...'):
                if det_arch == "db_resnet50":
                    det_model = db_resnet50(pretrained=True)
                else:
                    det_model = db_mobilenet_v3_large(pretrained=True)
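                # Wrap the detection model with docTR's preprocessing: resize to 1024x1024
                # and normalize with docTR's default mean/std for DB architectures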
                det_predictor = DetectionPredictor(
                    PreProcessor(
                        (1024, 1024),
                        batch_size=1,
                        mean=(0.798, 0.785, 0.772),
                        std=(0.264, 0.2749, 0.287),
                    ),
                    det_model,
                )
                rec_processor, rec_model = get_processor_model(rec_arch)
            with st.spinner('Analyzing...'):
                # Forward the image to the model
                processed_batches = det_predictor.pre_processor([doc[page_idx]])
                out = det_predictor.model(
                    processed_batches[0], return_model_output=True, return_preds=True
                )
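                # out_map is the raw probability map; resize it to the page size for display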
                seg_map = out["out_map"]
                seg_map = torch.squeeze(seg_map[0, ...], dim=0)
                seg_map = cv2.resize(seg_map.detach().numpy(), (doc[page_idx].shape[1], doc[page_idx].shape[0]),
                                     interpolation=cv2.INTER_LINEAR)
                # Plot the raw heatmap
                fig, ax = plt.subplots()
                ax.imshow(seg_map)
                ax.axis('off')
                cols[1].pyplot(fig)

                # Plot OCR output
                # Localize text elements
                loc_preds = out["preds"]

                # Check whether crop mode should be switched to channels first
                channels_last = len(doc) == 0 or isinstance(doc[0], np.ndarray)

                # Crop images
                crops, loc_preds = _OCRPredictor._prepare_crops(
                    [doc[page_idx]], loc_preds, channels_last=channels_last, assume_straight_pages=True
                )

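                # Recognize each crop with TrOCR. IAMDataset (from the local trocr module) is
                # assumed to wrap the crops with the TrOCR processor so they can be batched.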
                test_dataset = IAMDataset(crops[0], rec_processor)
                test_dataloader = DataLoader(test_dataset, batch_size=16)

                text = []
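                # Decode batch by batch; gradients are not needed at inference time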
                with torch.no_grad():
                    for batch in test_dataloader:
                        pixel_values = batch["pixel_values"].to(device)
                        generated_ids = rec_model.generate(pixel_values)
                        generated_text = rec_processor.batch_decode(
                            generated_ids, skip_special_tokens=True)
                        text.extend(generated_text)
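                # Re-associate the decoded strings with their box coordinates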
                boxes, text_preds = _OCRPredictor._process_predictions(
                    loc_preds, text)

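                # Aggregate words into lines and blocks; the local DocumentBuilder is assumed
                # to return one Tesseract-style dataframe per page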
                doc_builder = DocumentBuilder()
                out = doc_builder(
                    boxes,
                    text_preds,
                    [
                        # type: ignore[misc]
                        page.shape[:2] if channels_last else page.shape[-2:]
                        for page in [doc[page_idx]]
                    ]
                )
                
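                # Render the recognized text and the dataframe for each page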
                for df in out:
                    st.markdown("**Text**")
                    st.write(" ".join(df["word"].to_list()))
                    st.write('\n')
                    st.markdown("Dataframe output (similar to Tesseract):")
                    st.dataframe(df)
                    


if __name__ == '__main__':
    main()