Ankur Goyal commited on
Commit
2359223
1 Parent(s): d229b67

Switch to Gradio

Browse files
Files changed (5) hide show
  1. README.md +2 -4
  2. app.py +137 -141
  3. contract.jpeg +0 -0
  4. invoice.png +0 -0
  5. statement.png +0 -0
README.md CHANGED
@@ -3,10 +3,8 @@ title: DocQuery
3
  emoji: 🦉
4
  colorFrom: gray
5
  colorTo: pink
6
- sdk: streamlit
7
- sdk_version: 1.10.0
8
  app_file: app.py
9
  pinned: true
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
3
  emoji: 🦉
4
  colorFrom: gray
5
  colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 3.1.7
8
  app_file: app.py
9
  pinned: true
10
  ---
 
 
app.py CHANGED
@@ -2,15 +2,13 @@ import os
2
 
3
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
4
 
5
- from PIL import ImageDraw
6
- import streamlit as st
7
- from st_clickable_images import clickable_images
8
-
9
- st.set_page_config(layout="wide")
10
 
11
  import torch
12
  from docquery.pipeline import get_pipeline
13
- from docquery.document import load_bytes, load_document
14
 
15
 
16
  def ensure_list(x):
@@ -25,15 +23,21 @@ CHECKPOINTS = {
25
  "Donut 🍩": "naver-clova-ix/donut-base-finetuned-docvqa",
26
  }
27
 
 
 
28
 
29
- @st.experimental_singleton(show_spinner=False)
30
  def construct_pipeline(model):
 
 
 
 
31
  device = "cuda" if torch.cuda.is_available() else "cpu"
32
  ret = get_pipeline(checkpoint=CHECKPOINTS[model], device=device)
 
33
  return ret
34
 
35
 
36
- @st.cache(show_spinner=False)
37
  def run_pipeline(model, question, document, top_k):
38
  pipeline = construct_pipeline(model)
39
  return pipeline(question=question, **document.context, top_k=top_k)
@@ -59,150 +63,142 @@ def normalize_bbox(box, width, height):
59
  return [pct[0] * width, pct[1] * height, pct[2] * width, pct[3] * height]
60
 
61
 
62
- st.markdown("# DocQuery: Query Documents w/ NLP")
63
-
64
- if "document" not in st.session_state:
65
- st.session_state["document"] = None
 
 
 
 
 
 
 
 
 
 
66
 
67
- if "last_clicked" not in st.session_state:
68
- st.session_state["last_clicked"] = None
69
 
70
- input_col, model_col = st.columns(2)
 
 
 
 
 
 
 
71
 
72
- with input_col:
73
- input_type = st.radio(
74
- "Pick an input type", ["Upload", "URL", "Examples"], horizontal=True
75
- )
76
 
77
- with model_col:
78
- model_type = st.radio("Pick a model", list(CHECKPOINTS.keys()), horizontal=True)
 
 
 
79
 
80
 
81
- def load_file_cb():
82
- if st.session_state.file_input is None:
83
- return
84
 
85
- file = st.session_state.file_input
86
- with loading_placeholder:
87
- with st.spinner("Processing..."):
88
- document = load_bytes(file, file.name)
89
- _ = document.context
90
- st.session_state.document = document
91
 
 
 
 
92
 
93
- def load_url_cb():
94
- if st.session_state.url_input is None:
95
- return
 
 
 
 
 
96
 
97
- url = st.session_state.url_input
98
- with loading_placeholder:
99
- with st.spinner("Downloading..."):
100
- document = load_document(url)
101
- with st.spinner("Processing..."):
102
- _ = document.context
103
- st.session_state.document = document
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
 
 
 
 
 
105
 
106
- examples = [
107
- (
108
- "https://templates.invoicehome.com/invoice-template-us-neat-750px.png",
109
- "What is the invoice number?",
110
- ),
111
- (
112
- "https://miro.medium.com/max/787/1*iECQRIiOGTmEFLdWkVIH2g.jpeg",
113
- "What is the purchase amount?",
114
- ),
115
- (
116
- "https://www.accountingcoach.com/wp-content/uploads/2013/10/income-statement-example@2x.png",
117
- "What are net sales for 2020?",
118
- ),
119
- ]
120
- imgs_clicked = []
121
 
122
- if input_type == "Upload":
123
- file = st.file_uploader(
124
- "Upload a PDF or Image document", key="file_input", on_change=load_file_cb
 
125
  )
126
- elif input_type == "URL":
127
- url = st.text_input("URL", "", key="url_input", on_change=load_url_cb)
128
- elif input_type == "Examples":
129
- example_cols = st.columns(len(examples))
130
- for (i, (path, question)) in enumerate(examples):
131
- with example_cols[i]:
132
- imgs_clicked.append(
133
- clickable_images(
134
- [path],
135
- div_style={
136
- "display": "flex",
137
- "justify-content": "center",
138
- "flex-wrap": "wrap",
139
- "cursor": "pointer",
140
- },
141
- img_style={"margin": "5px", "height": "200px"},
142
- )
143
- )
144
- st.markdown(
145
- f"<p style='text-align: center'>{question}</p>",
146
- unsafe_allow_html=True,
147
- )
148
- print(imgs_clicked)
149
- imgs_clicked = [-1] * len(imgs_clicked)
150
-
151
- # clicked = clickable_images(
152
- # [x[0] for x in examples],
153
- # titles=[x[1] for x in examples],
154
- # div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
155
- # img_style={"margin": "5px", "height": "200px"},
156
- # )
157
- #
158
- # st.markdown(f"Image #{clicked} clicked" if clicked > -1 else "No image clicked")
159
-
160
-
161
- question = st.text_input("QUESTION", "", key="question")
162
-
163
- document = st.session_state.document
164
- loading_placeholder = st.empty()
165
- if document is not None:
166
- col1, col2 = st.columns(2)
167
- image = document.preview
168
-
169
- question = st.session_state.question
170
- colors = ["blue", "red", "green"]
171
- if document is not None and question is not None and len(question) > 0:
172
- col2.header(f"Answers ({model_type})")
173
- with col2:
174
- answers_placeholder = st.container()
175
- answers_loading_placeholder = st.container()
176
-
177
- with answers_loading_placeholder:
178
- # Run this (one-time) expensive operation outside of the processing
179
- # question placeholder
180
- with st.spinner("Constructing pipeline..."):
181
- construct_pipeline(model_type)
182
-
183
- with st.spinner("Processing question..."):
184
- predictions = run_pipeline(
185
- model=model_type, question=question, document=document, top_k=1
186
- )
187
-
188
- with answers_placeholder:
189
- image = image.copy()
190
- draw = ImageDraw.Draw(image)
191
- for i, p in enumerate(ensure_list(predictions)):
192
- col2.markdown(f"#### { p['answer'] }: ({round(p['score'] * 100, 1)}%)")
193
- if "start" in p and "end" in p:
194
- x1, y1, x2, y2 = normalize_bbox(
195
- expand_bbox(
196
- lift_word_boxes(document)[p["start"] : p["end"] + 1]
197
- ),
198
- image.width,
199
- image.height,
200
- )
201
- draw.rectangle(((x1, y1), (x2, y2)), outline=colors[i], width=3)
202
-
203
- if document is not None:
204
- col1.image(image, use_column_width="auto")
205
-
206
- "DocQuery uses LayoutLMv1 fine-tuned on DocVQA, a document visual question answering dataset, as well as SQuAD, which boosts its English-language comprehension. To use it, simply upload an image or PDF, type a question, and click 'submit', or click one of the examples to load them."
207
-
208
- "[Github Repo](https://github.com/impira/docquery)"
 
2
 
3
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
4
 
5
+ import functools
6
+ from PIL import Image, ImageDraw
7
+ import gradio as gr
 
 
8
 
9
  import torch
10
  from docquery.pipeline import get_pipeline
11
+ from docquery.document import load_bytes, load_document, ImageDocument
12
 
13
 
14
  def ensure_list(x):
 
23
  "Donut 🍩": "naver-clova-ix/donut-base-finetuned-docvqa",
24
  }
25
 
26
+ PIPELINES = {}
27
+
28
 
 
29
  def construct_pipeline(model):
30
+ global PIPELINES
31
+ if model in PIPELINES:
32
+ return PIPELINES[model]
33
+
34
  device = "cuda" if torch.cuda.is_available() else "cpu"
35
  ret = get_pipeline(checkpoint=CHECKPOINTS[model], device=device)
36
+ PIPELINES[model] = ret
37
  return ret
38
 
39
 
40
+ @functools.lru_cache(1024)
41
  def run_pipeline(model, question, document, top_k):
42
  pipeline = construct_pipeline(model)
43
  return pipeline(question=question, **document.context, top_k=top_k)
 
63
  return [pct[0] * width, pct[1] * height, pct[2] * width, pct[3] * height]
64
 
65
 
66
+ examples = [
67
+ [
68
+ "invoice.png",
69
+ "What is the invoice number?",
70
+ ],
71
+ [
72
+ "contract.jpeg",
73
+ "What is the purchase amount?",
74
+ ],
75
+ [
76
+ "statement.png",
77
+ "What are net sales for 2020?",
78
+ ],
79
+ ]
80
 
 
 
81
 
82
+ def process_path(path):
83
+ if path:
84
+ try:
85
+ document = load_document(path)
86
+ return document, document.preview, None
87
+ except Exception:
88
+ pass
89
+ return None, None, None
90
 
 
 
 
 
91
 
92
+ def process_upload(file):
93
+ if file:
94
+ return process_path(file.name)
95
+ else:
96
+ return None, None, None
97
 
98
 
99
+ colors = ["blue", "green", "black"]
 
 
100
 
 
 
 
 
 
 
101
 
102
+ def process_question(question, document, model=list(CHECKPOINTS.keys())[0]):
103
+ if document is None:
104
+ return None, None
105
 
106
+ predictions = run_pipeline(model, question, document, 3)
107
+ image = document.preview.copy()
108
+ draw = ImageDraw.Draw(image)
109
+ for i, p in enumerate(ensure_list(predictions)):
110
+ if i > 0:
111
+ # Keep the code around to produce multiple boxes, but only show the top
112
+ # prediction for now
113
+ break
114
 
115
+ if "start" in p and "end" in p:
116
+ x1, y1, x2, y2 = normalize_bbox(
117
+ expand_bbox(lift_word_boxes(document)[p["start"] : p["end"] + 1]),
118
+ image.width,
119
+ image.height,
120
+ )
121
+ draw.rectangle(((x1, y1), (x2, y2)), outline=colors[i], width=2)
122
+
123
+ return image, predictions
124
+
125
+
126
+ def load_example_document(img, question, model):
127
+ document = ImageDocument(Image.fromarray(img))
128
+ preview, answer = process_question(question, document, model)
129
+ return document, question, preview, answer
130
+
131
+
132
+ with gr.Blocks() as demo:
133
+ gr.Markdown("# DocQuery: Query Documents w/ NLP")
134
+ document = gr.Variable()
135
+ example_question = gr.Textbox(visible=False)
136
+ example_image = gr.Image(visible=False)
137
+
138
+ gr.Markdown("## 1. Upload a file or select an example")
139
+ with gr.Row(equal_height=True):
140
+ with gr.Column():
141
+ upload = gr.File(label="Upload a file", interactive=True)
142
+ url = gr.Textbox(label="... or a URL", interactive=True)
143
+ gr.Examples(
144
+ examples=examples,
145
+ inputs=[example_image, example_question],
146
+ )
147
+
148
+ gr.Markdown("## 2. Ask a question")
149
+
150
+ with gr.Row(equal_height=True):
151
+ # NOTE: When https://github.com/gradio-app/gradio/issues/2103 is resolved,
152
+ # we can support enter-key submit
153
+ question = gr.Textbox(
154
+ label="Question", placeholder="e.g. What is the invoice number?"
155
+ )
156
+ model = gr.Radio(
157
+ choices=list(CHECKPOINTS.keys()),
158
+ value=list(CHECKPOINTS.keys())[0],
159
+ label="Model",
160
+ )
161
+
162
+ with gr.Row():
163
+ clear_button = gr.Button("Clear", variant="secondary")
164
+ submit_button = gr.Button("Submit", variant="primary", elem_id="submit-button")
165
+
166
+ with gr.Row():
167
+ image = gr.Image(visible=True)
168
+ with gr.Column():
169
+ output = gr.JSON(label="Output")
170
+
171
+ clear_button.click(
172
+ lambda _: (None, None, None, None),
173
+ inputs=clear_button,
174
+ outputs=[image, document, question, output],
175
+ )
176
+ upload.change(fn=process_upload, inputs=[upload], outputs=[document, image, output])
177
+ url.change(fn=process_path, inputs=[url], outputs=[document, image, output])
178
 
179
+ submit_button.click(
180
+ process_question,
181
+ inputs=[question, document, model],
182
+ outputs=[image, output],
183
+ )
184
 
185
+ # This is handy but commented out for now because we can't "auto submit" questions either
186
+ # model.change(
187
+ # process_question, inputs=[question, document, model], outputs=[image, output]
188
+ # )
 
 
 
 
 
 
 
 
 
 
 
189
 
190
+ example_image.change(
191
+ fn=load_example_document,
192
+ inputs=[example_image, example_question, model],
193
+ outputs=[document, question, image, output],
194
  )
195
+
196
+ gr.Markdown("### More Info")
197
+ gr.Markdown("DocQuery uses LayoutLMv1 fine-tuned on DocVQA, a document visual question"
198
+ " answering dataset, as well as SQuAD, which boosts its English-language comprehension."
199
+ " To use it, simply upload an image or PDF, type a question, and click 'submit', or "
200
+ " click one of the examples to load them.")
201
+ gr.Markdown("[Github Repo](https://github.com/impira/docquery)")
202
+
203
+ if __name__ == "__main__":
204
+ demo.launch(debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
contract.jpeg ADDED
invoice.png ADDED
statement.png ADDED