Commit 9262ebb
curiousily committed
1 parent: 6e18cfe

Add app files

Files changed (2):
  1. app.py +126 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,126 @@
+import io
+
+import numpy as np
+import pandas as pd
+import plotly.express as px
+import streamlit as st
+import torch
+import torch.nn.functional as F
+from easyocr import Reader
+from PIL import Image
+from transformers import (
+    LayoutLMv3FeatureExtractor,
+    LayoutLMv3ForSequenceClassification,
+    LayoutLMv3Processor,
+    LayoutLMv3TokenizerFast,
+)
+
+DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
+MICROSOFT_MODEL_NAME = "microsoft/layoutlmv3-base"
+MODEL_NAME = "curiousily/layoutlmv3-financial-document-classification"
+
+
+def create_bounding_box(bbox_data, width_scale: float, height_scale: float):
+    # Collapse an EasyOCR polygon (a list of (x, y) corner points) into an
+    # axis-aligned box in the 0-1000 coordinate space that LayoutLMv3 expects.
+    xs = []
+    ys = []
+    for x, y in bbox_data:
+        xs.append(x)
+        ys.append(y)
+
+    left = int(min(xs) * width_scale)
+    top = int(min(ys) * height_scale)
+    right = int(max(xs) * width_scale)
+    bottom = int(max(ys) * height_scale)
+
+    return [left, top, right, bottom]
+
+
+# Cache the heavy resources so they are created once, not on every rerun.
+@st.experimental_singleton
+def create_ocr_reader():
+    return Reader(["en"])
+
+
+@st.experimental_singleton
+def create_processor():
+    # OCR is handled by easyocr, so disable the processor's built-in OCR.
+    feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
+    tokenizer = LayoutLMv3TokenizerFast.from_pretrained(MICROSOFT_MODEL_NAME)
+    return LayoutLMv3Processor(feature_extractor, tokenizer)
+
+
+@st.experimental_singleton
+def create_model():
+    model = LayoutLMv3ForSequenceClassification.from_pretrained(MODEL_NAME)
+    return model.eval().to(DEVICE)
+
+
+def predict(
+    image: Image.Image,
+    reader: Reader,
+    processor: LayoutLMv3Processor,
+    model: LayoutLMv3ForSequenceClassification,
+):
+    width, height = image.size
+    # easyocr accepts a file path, raw bytes, or a numpy array, so convert the PIL image.
+    ocr_result = reader.readtext(np.array(image))
+
+    width_scale = 1000 / width
+    height_scale = 1000 / height
+
+    words = []
+    boxes = []
+    for bbox, word, confidence in ocr_result:
+        words.append(word)
+        boxes.append(create_bounding_box(bbox, width_scale, height_scale))
+
+    encoding = processor(
+        image,
+        words,
+        boxes=boxes,
+        max_length=512,
+        padding="max_length",
+        truncation=True,
+        return_tensors="pt",
+    )
+
+    with torch.inference_mode():
+        output = model(
+            input_ids=encoding["input_ids"].to(DEVICE),
+            attention_mask=encoding["attention_mask"].to(DEVICE),
+            bbox=encoding["bbox"].to(DEVICE),
+            pixel_values=encoding["pixel_values"].to(DEVICE),
+        )
+    logits = output.logits
+    predicted_class = logits.argmax()
+    probabilities = F.softmax(logits, dim=-1).flatten().tolist()
+    return predicted_class, probabilities
+
+
+reader = create_ocr_reader()
+processor = create_processor()
+model = create_model()
+
+uploaded_file = st.file_uploader("Upload Document image", ["jpg", "png"])
+if uploaded_file is not None:
+    bytes_data = io.BytesIO(uploaded_file.getvalue())
+    # Convert to RGB so RGBA PNGs don't break the 3-channel feature extractor.
+    image = Image.open(bytes_data).convert("RGB")
+    predicted_class, probabilities = predict(image, reader, processor, model)
+    predicted_label = model.config.id2label[predicted_class.item()]
+
+    st.image(image, "Your document image")
+    st.markdown(f"Predicted document type: **{predicted_label}**")
+
+    df_predictions = pd.DataFrame(
+        {"Document": list(model.config.id2label.values()), "Confidence": probabilities}
+    )
+
+    fig = px.bar(
+        df_predictions,
+        x="Document",
+        y="Confidence",
+    )
+    st.plotly_chart(fig, use_container_width=True)
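
As a quick sanity check of the 0-1000 box normalization above (hypothetical numbers, using create_bounding_box as defined in app.py): on a 2000x1000 px page, width_scale = 1000 / 2000 = 0.5 and height_scale = 1000 / 1000 = 1.0, so a word polygon spanning x in [100, 300] and y in [50, 90] maps to [50, 50, 150, 90]:

    # Hypothetical values, not taken from a real OCR run.
    polygon = [[100, 50], [300, 50], [300, 90], [100, 90]]  # EasyOCR corner points
    width_scale, height_scale = 1000 / 2000, 1000 / 1000    # 2000x1000 px page
    assert create_bounding_box(polygon, width_scale, height_scale) == [50, 50, 150, 90]
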
requirements.txt ADDED
@@ -0,0 +1,6 @@
+easyocr==1.6.2
+pandas==1.5.3
+Pillow==9.4.0
+plotly-express==0.4.1
+torch==1.13.1
+transformers==4.25.1
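
Note that streamlit itself is not pinned here: on a Hugging Face Space the runtime supplies it, so running the app elsewhere means installing these pins plus a Streamlit release that still provides st.experimental_singleton (later replaced by st.cache_resource), then launching with streamlit run app.py. Below is a minimal standalone smoke test of the same pipeline, a sketch that assumes a CPU-only machine (so DEVICE in app.py resolves to "cpu"), a hypothetical local image sample.png, and the predict helper (with its imports and create_bounding_box) copied from app.py:

    # Standalone sketch, not part of the commit. sample.png is hypothetical;
    # predict() and create_bounding_box() are copied verbatim from app.py.
    from easyocr import Reader
    from PIL import Image
    from transformers import (
        LayoutLMv3FeatureExtractor,
        LayoutLMv3ForSequenceClassification,
        LayoutLMv3Processor,
        LayoutLMv3TokenizerFast,
    )

    reader = Reader(["en"])
    processor = LayoutLMv3Processor(
        LayoutLMv3FeatureExtractor(apply_ocr=False),
        LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base"),
    )
    model = LayoutLMv3ForSequenceClassification.from_pretrained(
        "curiousily/layoutlmv3-financial-document-classification"
    ).eval()

    image = Image.open("sample.png").convert("RGB")
    predicted_class, probabilities = predict(image, reader, processor, model)
    print(model.config.id2label[predicted_class.item()])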