laverdes committed
Commit 43a5321 · Parent: 78588fe

feat: add DonutProcessor and predict method

Files changed (1): app.py (+45 -7)
app.py CHANGED
@@ -2,14 +2,51 @@ import torch
 import streamlit as st

 from PIL import Image
-from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig # , DonutProcessor
+from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig , DonutProcessor


-def demo_process(input_img):
-    global pretrained_model, task_prompt # , task_name
-    # input_img = Image.fromarray(input_img)
-    output = pretrained_model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
-    return output
+def run_prediction(sample):
+    global pretrained_model, processor, task_prompt
+    if isinstance(sample, dict):
+        # prepare inputs
+        pixel_values = torch.tensor(sample["pixel_values"]).unsqueeze(0)
+    else:  # sample is an image
+        # prepare encoder inputs
+        pixel_values = processor(image, return_tensors="pt").pixel_values
+
+    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
+
+    # run inference
+    outputs = pretrained_model.generate(
+        pixel_values.to(device),
+        decoder_input_ids=decoder_input_ids.to(device),
+        max_length=pretrained_model.decoder.config.max_position_embeddings,
+        early_stopping=True,
+        pad_token_id=processor.tokenizer.pad_token_id,
+        eos_token_id=processor.tokenizer.eos_token_id,
+        use_cache=True,
+        num_beams=1,
+        bad_words_ids=[[processor.tokenizer.unk_token_id]],
+        return_dict_in_generate=True,
+    )
+
+    # process output
+    prediction = processor.batch_decode(outputs.sequences)[0]
+
+    # post-processing
+    if "cord" in task_prompt:
+        prediction = prediction.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
+        prediction = re.sub(r"<.*?>", "", prediction, count=1).strip()  # remove first task start token
+        prediction = processor.token2json(prediction)
+
+    # load reference target
+    if isinstance(sample, dict):
+        target = processor.token2json(sample["target_sequence"])
+    else:
+        target = "<not_provided>"
+
+    return prediction, target
+

 task_prompt = f"<s>"

@@ -30,10 +67,11 @@ image = Image.open(f"./img/receipt-{receipt}.jpg")
 st.image(image, caption='Your target receipt')

 st.text(f'baking the 🍩...')
+processor = DonutProcessor.from_pretrained("unstructuredio/donut-base-sroie")
 pretrained_model = VisionEncoderDecoderModel.from_pretrained("unstructuredio/donut-base-sroie")
 pretrained_model.encoder.to(torch.bfloat16)
 pretrained_model.eval()

 st.text(f'parsing receipt..')
-parsed_receipt_info = demo_process(image)
+parsed_receipt_info = run_prediction(image)
 st.text(f'\nRaw output:\n{parsed_receipt_info}')
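
For context, here is a minimal standalone sketch of the inference path this commit introduces: DonutProcessor prepares the encoder pixel values and decoder prompt ids, VisionEncoderDecoderModel.generate produces the sequence, and token2json turns it into structured output. The hunks above do not show where `re` is imported or `device` is defined, so both are assumptions here, as is the example receipt path `./img/receipt-1.jpg`; the checkpoint name and generate arguments are taken from the diff.

# sketch under the assumptions stated above, not the exact app.py
import re

import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

device = "cuda" if torch.cuda.is_available() else "cpu"  # assumed; not shown in the diff

processor = DonutProcessor.from_pretrained("unstructuredio/donut-base-sroie")
model = VisionEncoderDecoderModel.from_pretrained("unstructuredio/donut-base-sroie").to(device)
model.eval()

task_prompt = "<s>"
image = Image.open("./img/receipt-1.jpg").convert("RGB")  # illustrative sample path

# encoder input: pixel values from the processor; decoder input: the task prompt ids
pixel_values = processor(image, return_tensors="pt").pixel_values
decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

outputs = model.generate(
    pixel_values.to(device),
    decoder_input_ids=decoder_input_ids.to(device),
    max_length=model.decoder.config.max_position_embeddings,
    early_stopping=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    use_cache=True,
    num_beams=1,
    bad_words_ids=[[processor.tokenizer.unk_token_id]],
    return_dict_in_generate=True,
)

# decode, strip special tokens and the leading task token, then convert to JSON
prediction = processor.batch_decode(outputs.sequences)[0]
prediction = prediction.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
prediction = re.sub(r"<.*?>", "", prediction, count=1).strip()
print(processor.token2json(prediction))

In the Streamlit app itself, run_prediction(image) returns a (prediction, target) tuple, with target set to "<not_provided>" when a plain image rather than a dataset sample is passed in.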