vincentclaes commited on
Commit
05957fd
·
1 Parent(s): 40103b7

first working version

Browse files
.gitattributes CHANGED
@@ -2,13 +2,11 @@
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
  *.model filter=lfs diff=lfs merge=lfs -text
13
  *.msgpack filter=lfs diff=lfs merge=lfs -text
14
  *.npy filter=lfs diff=lfs merge=lfs -text
@@ -16,13 +14,12 @@
16
  *.onnx filter=lfs diff=lfs merge=lfs -text
17
  *.ot filter=lfs diff=lfs merge=lfs -text
18
  *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
  *.pickle filter=lfs diff=lfs merge=lfs -text
21
  *.pkl filter=lfs diff=lfs merge=lfs -text
 
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
  *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
28
  *.tflite filter=lfs diff=lfs merge=lfs -text
 
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
 
5
  *.ftz filter=lfs diff=lfs merge=lfs -text
6
  *.gz filter=lfs diff=lfs merge=lfs -text
7
  *.h5 filter=lfs diff=lfs merge=lfs -text
8
  *.joblib filter=lfs diff=lfs merge=lfs -text
9
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
 
10
  *.model filter=lfs diff=lfs merge=lfs -text
11
  *.msgpack filter=lfs diff=lfs merge=lfs -text
12
  *.npy filter=lfs diff=lfs merge=lfs -text
 
14
  *.onnx filter=lfs diff=lfs merge=lfs -text
15
  *.ot filter=lfs diff=lfs merge=lfs -text
16
  *.parquet filter=lfs diff=lfs merge=lfs -text
 
17
  *.pickle filter=lfs diff=lfs merge=lfs -text
18
  *.pkl filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
  *.pt filter=lfs diff=lfs merge=lfs -text
21
  *.pth filter=lfs diff=lfs merge=lfs -text
22
  *.rar filter=lfs diff=lfs merge=lfs -text
 
23
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
24
  *.tar.* filter=lfs diff=lfs merge=lfs -text
25
  *.tflite filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ venv
2
+ *.swo
3
+ *.swp
4
+ *.pyc
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: DocumentQAComparator
3
- emoji: 📈
4
- colorFrom: red
5
- colorTo: pink
6
  sdk: gradio
7
  sdk_version: 3.18.0
8
  app_file: app.py
@@ -11,3 +11,7 @@ license: mit
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
1
  ---
2
  title: DocumentQAComparator
3
+ emoji: 🤖🦾⚙️
4
+ colorFrom: blue
5
+ colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 3.18.0
8
  app_file: app.py
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
14
+
15
+ ```
16
+ pip install -r requirements.txt
17
+ ```
app.py ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import boto3
4
+ import traceback
5
+
6
+ import gradio as gr
7
+ from PIL import Image, ImageDraw
8
+
9
+ from docquery.document import load_document, ImageDocument
10
+ from docquery.ocr_reader import get_ocr_reader
11
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
12
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
13
+
14
+ # avoid ssl errors
15
+ import ssl
16
+ ssl._create_default_https_context = ssl._create_unverified_context
17
+
18
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
19
+
20
+
21
+ def ensure_list(x):
22
+ if isinstance(x, list):
23
+ return x
24
+ else:
25
+ return [x]
26
+
27
+
28
+ CHECKPOINTS = {
29
+ # "LayoutLMv1 🦉": "impira/layoutlm-document-qa",
30
+ # "LayoutLMv1 for Invoices 💸": "impira/layoutlm-invoices",
31
+ "Textract Query": "Textract",
32
+ "LayoutLM FineTuned": "LayoutLM FineTuned",
33
+ "Donut": "naver-clova-ix/donut-base-finetuned-rvlcdip",
34
+ "LiLT": "philschmid/lilt-en-funsd",
35
+ # "LiLT" : "nielsr/lilt-xlm-roberta-base"
36
+ }
37
+
38
+ PIPELINES = {}
39
+ #
40
+ #
41
+ # def construct_pipeline(task, model):
42
+ # global PIPELINES
43
+ # if model in PIPELINES:
44
+ # return PIPELINES[model]
45
+ #
46
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
47
+ # ret = pipeline(task=task, model=CHECKPOINTS[model], device=device)
48
+ # PIPELINES[model] = ret
49
+ # return ret
50
+
51
+
52
+ def image_to_byte_array(image: Image) -> bytes:
53
+ image_as_byte_array = io.BytesIO()
54
+ image.save(image_as_byte_array, format="PNG")
55
+ image_as_byte_array = image_as_byte_array.getvalue()
56
+ return image_as_byte_array
57
+
58
+
59
+ def run_textract_query(question, document):
60
+ image_as_byte_base64 = image_to_byte_array(image=document.b)
61
+ response = boto3.client('textract').analyze_document(
62
+ Document={
63
+ 'Bytes': image_as_byte_base64,
64
+ },
65
+ FeatureTypes=[
66
+ 'QUERIES',
67
+ ],
68
+ QueriesConfig={
69
+ 'Queries': [
70
+ {
71
+ 'Text': question,
72
+ 'Pages': [
73
+ '*',
74
+ ]
75
+ },
76
+ ]
77
+ }
78
+ )
79
+ for element in response["Blocks"]:
80
+ if element["BlockType"] == "QUERY_RESULT":
81
+ return {
82
+ "score": element["Confidence"],
83
+ "answer": element["Text"],
84
+ # "word_ids": element
85
+ }
86
+ else:
87
+ Exception("No QUERY_RESULT found in the response from Textract.")
88
+
89
+
90
+ def run_layoutlm_finetuned(question, document):
91
+ from transformers import pipeline
92
+
93
+ nlp = pipeline(
94
+ "document-question-answering",
95
+ model="impira/layoutlm-document-qa",
96
+ )
97
+
98
+ result = nlp(document.context["image"][0][0], question)[0]
99
+ # [{'score': 0.9999411106109619, 'answer': 'LETTER OF CREDIT', 'start': 106, 'end': 108}]
100
+ return {
101
+ "score": result["score"],
102
+ "answer": result["answer"],
103
+ "word_ids": [result["start"], result["end"]],
104
+ "page": 0
105
+ }
106
+
107
+
108
+ def run_lilt_model(question, document):
109
+
110
+ # use this model + tokenizer
111
+ lilt_tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-infoxlm-base")
112
+ model = AutoModelForQuestionAnswering.from_pretrained("nielsr/lilt-xlm-roberta-base")
113
+
114
+ processed_document = document.context["image"][0][1]
115
+ words = [x[0] for x in processed_document]
116
+ boxes = [x[1] for x in processed_document]
117
+
118
+ encoding = lilt_tokenizer(text=question, text_pair=words, boxes=boxes, add_special_tokens=True, return_tensors="pt")
119
+
120
+ outputs = model(**encoding)
121
+
122
+ answer_start_index = outputs.start_logits.argmax()
123
+ answer_end_index = outputs.end_logits.argmax()
124
+
125
+ predict_answer_tokens = encoding.input_ids[0, answer_start_index: answer_end_index + 1]
126
+ predict_answer = lilt_tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)
127
+ return {
128
+ "score": "n/a",
129
+ "answer": predict_answer,
130
+ # "word_ids": element
131
+ }
132
+
133
+
134
+ def run_donut(question, document):
135
+
136
+ # nlp = pipeline(
137
+ # "document-question-answering",
138
+ # model="naver-clova-ix/donut-base-finetuned-docvqa",
139
+ # )
140
+ #
141
+ # result = nlp(document.context["image"][0][0], question)[0]
142
+ # # [{'score': 0.9999411106109619, 'answer': 'LETTER OF CREDIT', 'start': 106, 'end': 108}]
143
+ # return {
144
+ # "score": result["score"],
145
+ # "answer": result["answer"],
146
+ # "word_ids": [result["start"], result["end"]],
147
+ # "page": 0
148
+ # }
149
+
150
+ donut_processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
151
+ donut_model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
152
+ # prepare encoder inputs
153
+ pixel_values = donut_processor(document.context["image"][0][0], return_tensors="pt").pixel_values
154
+
155
+ # prepare decoder inputs
156
+ task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
157
+ prompt = task_prompt.replace("{user_input}", question)
158
+ decoder_input_ids = donut_processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
159
+
160
+ # generate answer
161
+ outputs = donut_model.generate(
162
+ pixel_values,
163
+ decoder_input_ids=decoder_input_ids,
164
+ max_length=donut_model.decoder.config.max_position_embeddings,
165
+ early_stopping=True,
166
+ pad_token_id=donut_processor.tokenizer.pad_token_id,
167
+ eos_token_id=donut_processor.tokenizer.eos_token_id,
168
+ use_cache=True,
169
+ num_beams=1,
170
+ bad_words_ids=[[donut_processor.tokenizer.unk_token_id]],
171
+ return_dict_in_generate=True,
172
+ )
173
+ import re
174
+ # postprocess
175
+ sequence = donut_processor.batch_decode(outputs.sequences)[0]
176
+ sequence = sequence.replace(donut_processor.tokenizer.eos_token, "").replace(donut_processor.tokenizer.pad_token, "")
177
+ sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
178
+
179
+ result = donut_processor.token2json(sequence)
180
+ return {
181
+ "score": "n/a",
182
+ "answer": result["answer"],
183
+ # "word_ids": element
184
+ }
185
+
186
+
187
+ def run_pipeline(model, question, document, top_k):
188
+ """ Run pipeline selected by the user.
189
+ :return: expect an object like
190
+ [{'score': 0.251716673374176, 'answer': 'CREDIT', 'word_ids': [38], 'page': 0},
191
+ {'score': 0.15292450785636902, 'answer': 'LETTER OF CREDIT', 'word_ids': [37, 38], 'page': 0},
192
+ {'score': 0.009600160643458366, 'answer': 'Payment Tens LETTER OF CREDIT', 'word_ids': [36, 37, 38], 'page': 0}]
193
+ """
194
+ if model == "Textract Query":
195
+ return run_textract_query(question, document)
196
+ elif model == "LiLT":
197
+ return run_lilt_model(question, document)
198
+ elif model == "LayoutLM FineTuned":
199
+ return run_layoutlm_finetuned(question=question, document=document)
200
+ elif model == "Donut":
201
+ return run_donut(question=question, document=document)
202
+ else:
203
+ return {"answer": "model not found", "score": "n/a"}
204
+
205
+
206
+
207
+ def process_path(path):
208
+ error = None
209
+ if path:
210
+ try:
211
+ document = load_document(path)
212
+ return (
213
+ document,
214
+ gr.update(visible=True, value=document.preview),
215
+ gr.update(visible=True),
216
+ gr.update(visible=False, value=None),
217
+ gr.update(visible=False, value=None),
218
+ None,
219
+ )
220
+ except Exception as e:
221
+ traceback.print_exc()
222
+ error = str(e)
223
+ return (
224
+ None,
225
+ gr.update(visible=False, value=None),
226
+ gr.update(visible=False),
227
+ gr.update(visible=False, value=None),
228
+ gr.update(visible=False, value=None),
229
+ gr.update(visible=True, value=error) if error is not None else None,
230
+ None,
231
+ )
232
+
233
+ def process_upload(file):
234
+ if file:
235
+ return process_path(file.name)
236
+ else:
237
+ return (
238
+ None,
239
+ gr.update(visible=False, value=None),
240
+ gr.update(visible=False),
241
+ gr.update(visible=False, value=None),
242
+ gr.update(visible=False, value=None),
243
+ None,
244
+ )
245
+
246
+
247
+ def lift_word_boxes(document, page):
248
+ return document.context["image"][page][1]
249
+
250
+
251
+ def expand_bbox(word_boxes):
252
+ if len(word_boxes) == 0:
253
+ return None
254
+
255
+ min_x, min_y, max_x, max_y = zip(*[x[1] for x in word_boxes])
256
+ min_x, min_y, max_x, max_y = [min(min_x), min(min_y), max(max_x), max(max_y)]
257
+ return [min_x, min_y, max_x, max_y]
258
+
259
+
260
+ # LayoutLM boxes are normalized to 0, 1000
261
+ def normalize_bbox(box, width, height, padding=0.005):
262
+ min_x, min_y, max_x, max_y = [c / 1000 for c in box]
263
+ if padding != 0:
264
+ min_x = max(0, min_x - padding)
265
+ min_y = max(0, min_y - padding)
266
+ max_x = min(max_x + padding, 1)
267
+ max_y = min(max_y + padding, 1)
268
+ return [min_x * width, min_y * height, max_x * width, max_y * height]
269
+
270
+
271
+ def process_question(question, document, model=list(CHECKPOINTS.keys())[0]):
272
+ prediction = run_pipeline(model, question, document, 3)
273
+ pages = [x.copy().convert("RGB") for x in document.preview]
274
+ text_value = prediction["answer"]
275
+ if "word_ids" in prediction:
276
+ image = pages[prediction["page"]]
277
+ draw = ImageDraw.Draw(image, "RGBA")
278
+ word_boxes = lift_word_boxes(document, prediction["page"])
279
+ x1, y1, x2, y2 = normalize_bbox(
280
+ expand_bbox([word_boxes[i] for i in prediction["word_ids"]]),
281
+ image.width,
282
+ image.height,
283
+ )
284
+ draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))
285
+
286
+ return (
287
+ gr.update(visible=True, value=pages),
288
+ gr.update(visible=True, value=prediction),
289
+ gr.update(
290
+ visible=True,
291
+ value=text_value,
292
+ ),
293
+ )
294
+
295
+
296
+ def load_example_document(img, question, model):
297
+ if img is not None:
298
+ document = ImageDocument(Image.fromarray(img), get_ocr_reader())
299
+ preview, answer, answer_text = process_question(question, document, model)
300
+ return document, question, preview, gr.update(visible=True), answer, answer_text
301
+ else:
302
+ return None, None, None, gr.update(visible=False), None, None
303
+
304
+
305
+ CSS = """
306
+ #question input {
307
+ font-size: 16px;
308
+ }
309
+ #url-textbox {
310
+ padding: 0 !important;
311
+ }
312
+ #short-upload-box .w-full {
313
+ min-height: 10rem !important;
314
+ }
315
+ /* I think something like this can be used to re-shape
316
+ * the table
317
+ */
318
+ /*
319
+ .gr-samples-table tr {
320
+ display: inline;
321
+ }
322
+ .gr-samples-table .p-2 {
323
+ width: 100px;
324
+ }
325
+ */
326
+ #select-a-file {
327
+ width: 100%;
328
+ }
329
+ #file-clear {
330
+ padding-top: 2px !important;
331
+ padding-bottom: 2px !important;
332
+ padding-left: 8px !important;
333
+ padding-right: 8px !important;
334
+ margin-top: 10px;
335
+ }
336
+ .gradio-container .gr-button-primary {
337
+ background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
338
+ border: 1px solid #B0DCCC;
339
+ border-radius: 8px;
340
+ color: #1B8700;
341
+ }
342
+ .gradio-container.dark button#submit-button {
343
+ background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
344
+ border: 1px solid #B0DCCC;
345
+ border-radius: 8px;
346
+ color: #1B8700
347
+ }
348
+
349
+ table.gr-samples-table tr td {
350
+ border: none;
351
+ outline: none;
352
+ }
353
+
354
+ table.gr-samples-table tr td:first-of-type {
355
+ width: 0%;
356
+ }
357
+
358
+ div#short-upload-box div.absolute {
359
+ display: none !important;
360
+ }
361
+
362
+ gradio-app > div > div > div > div.w-full > div, .gradio-app > div > div > div > div.w-full > div {
363
+ gap: 0px 2%;
364
+ }
365
+
366
+ gradio-app div div div div.w-full, .gradio-app div div div div.w-full {
367
+ gap: 0px;
368
+ }
369
+
370
+ gradio-app h2, .gradio-app h2 {
371
+ padding-top: 10px;
372
+ }
373
+
374
+ #answer {
375
+ overflow-y: scroll;
376
+ color: white;
377
+ background: #666;
378
+ border-color: #666;
379
+ font-size: 20px;
380
+ font-weight: bold;
381
+ }
382
+
383
+ #answer span {
384
+ color: white;
385
+ }
386
+
387
+ #answer textarea {
388
+ color:white;
389
+ background: #777;
390
+ border-color: #777;
391
+ font-size: 18px;
392
+ }
393
+
394
+ #url-error input {
395
+ color: red;
396
+ }
397
+ """
398
+
399
+ examples = [
400
+
401
+ [
402
+ "scenario-1.png",
403
+ "What is the final consignee?",
404
+ ],
405
+ [
406
+ "scenario-1.png",
407
+ "What are the payment terms?",
408
+ ],
409
+ [
410
+ "scenario-2.png",
411
+ "What is the actual manufacturer?",
412
+ ],
413
+ [
414
+ "scenario-3.png",
415
+ 'What is the "ship to" destination?',
416
+ ],
417
+ [
418
+ "scenario-4.png",
419
+ 'What is the color?',
420
+ ],
421
+ [
422
+ "scenario-5.png",
423
+ 'What is the "said to contain"?',
424
+ ],
425
+ [
426
+ "scenario-5.png",
427
+ 'What is the "Net Weight"?',
428
+ ],
429
+ [
430
+ "scenario-5.png",
431
+ 'What is the "Freight Collect"?',
432
+ ],
433
+ [
434
+ "bill_of_lading_1.png",
435
+ "What is the shipper?",
436
+ ],
437
+ [
438
+ "bill_of_lading_1.png",
439
+ "What is the consignee?",
440
+ ],
441
+ [
442
+ "bill_of_lading_1.png",
443
+ "What is the consignee id?",
444
+ ],
445
+ [
446
+ "bill_of_lading_1.png",
447
+ "What is the carrier id?",
448
+ ],
449
+ [
450
+ "bill_of_lading_1.png",
451
+ "What is the description of the products?",
452
+ ],
453
+ [
454
+ "bill_of_lading_1.png",
455
+ "What is the quantity of the products?",
456
+ ],
457
+ ]
458
+
459
+ with gr.Blocks(css=CSS) as demo:
460
+ gr.Markdown("# Document Query Engine")
461
+ gr.Markdown(
462
+ "Original version comes from DocQuery [here](https://huggingface.co/spaces/impira/docquery) (created by [Impira](https://impira.com?utm_source=huggingface&utm_medium=referral&utm_campaign=docquery_space))"
463
+ )
464
+
465
+ document = gr.Variable()
466
+ example_question = gr.Textbox(visible=False)
467
+ example_image = gr.Image(visible=False)
468
+
469
+ with gr.Row(equal_height=True):
470
+ with gr.Column():
471
+ with gr.Row():
472
+ gr.Markdown("## 1. Select a file", elem_id="select-a-file")
473
+ img_clear_button = gr.Button(
474
+ "Clear", variant="secondary", elem_id="file-clear", visible=False
475
+ )
476
+ image = gr.Gallery(visible=False)
477
+ upload = gr.File(label=None, interactive=True, elem_id="short-upload-box")
478
+ gr.Examples(
479
+ examples=examples,
480
+ inputs=[example_image, example_question],
481
+ )
482
+
483
+ with gr.Column() as col:
484
+ gr.Markdown("## 2. Ask a question")
485
+ question = gr.Textbox(
486
+ label="Question",
487
+ placeholder="e.g. What is the invoice number?",
488
+ lines=1,
489
+ max_lines=1,
490
+ )
491
+ model = gr.Radio(
492
+ choices=list(CHECKPOINTS.keys()),
493
+ value=list(CHECKPOINTS.keys())[0],
494
+ label="Model",
495
+ )
496
+
497
+ with gr.Row():
498
+ clear_button = gr.Button("Clear", variant="secondary")
499
+ submit_button = gr.Button(
500
+ "Submit", variant="primary", elem_id="submit-button"
501
+ )
502
+ with gr.Column():
503
+ output_text = gr.Textbox(
504
+ label="Top Answer", visible=False, elem_id="answer"
505
+ )
506
+ output = gr.JSON(label="Output", visible=False)
507
+
508
+ for cb in [img_clear_button, clear_button]:
509
+ cb.click(
510
+ lambda _: (
511
+ gr.update(visible=False, value=None),
512
+ None,
513
+ gr.update(visible=False, value=None),
514
+ gr.update(visible=False, value=None),
515
+ gr.update(visible=False),
516
+ None,
517
+ None,
518
+ None,
519
+ gr.update(visible=False, value=None),
520
+ None,
521
+ ),
522
+ inputs=clear_button,
523
+ outputs=[
524
+ image,
525
+ document,
526
+ output,
527
+ output_text,
528
+ img_clear_button,
529
+ example_image,
530
+ upload,
531
+ question,
532
+ ],
533
+ )
534
+
535
+ upload.change(
536
+ fn=process_upload,
537
+ inputs=[upload],
538
+ outputs=[document, image, img_clear_button, output, output_text],
539
+ )
540
+
541
+ question.submit(
542
+ fn=process_question,
543
+ inputs=[question, document, model],
544
+ outputs=[image, output, output_text],
545
+ )
546
+
547
+ submit_button.click(
548
+ process_question,
549
+ inputs=[question, document, model],
550
+ outputs=[image, output, output_text],
551
+ )
552
+
553
+ model.change(
554
+ process_question,
555
+ inputs=[question, document, model],
556
+ outputs=[image, output, output_text],
557
+ )
558
+
559
+ example_image.change(
560
+ fn=load_example_document,
561
+ inputs=[example_image, example_question, model],
562
+ outputs=[document, question, image, img_clear_button, output, output_text],
563
+ )
564
+
565
+ if __name__ == "__main__":
566
+ demo.launch(enable_queue=False)
bill_of_lading_1.png ADDED
packages.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ poppler-utils
2
+ tesseract-ocr
3
+ chromium
4
+ chromium-driver
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ tesseract-ocr
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torc
2
+ docquery[web,donut]
3
+ transformers
4
+ gradio
5
+ boto3
6
+ pillow
7
+
scenario-1.png ADDED
scenario-2.png ADDED
scenario-3.png ADDED
scenario-4.png ADDED
scenario-5.png ADDED