Andreas Lukito committed on
Commit
5e03bd0
·
1 Parent(s): ac734e5

Update sample code

Browse files
Files changed (1) hide show
  1. app.py +59 -4
app.py CHANGED
@@ -1,7 +1,62 @@
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
+ import torch
4
+ import re
5
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
6
 
7
def load_and_preprocess_image(image, processor):
    """Convert an image into the pixel tensor fed to the model.

    Args:
        image: Image to encode; it is passed to ``processor`` unchanged.
        processor: Callable invoked as ``processor(image, return_tensors="pt")``
            whose result exposes a ``pixel_values`` attribute.

    Returns:
        The ``pixel_values`` attribute of the processor's output.
    """
    return processor(image, return_tensors="pt").pixel_values
13
+
14
def generate_text_from_image(model, image, processor, device):
    """Run the model on one image and return the parsed prediction.

    Args:
        model: Model providing ``eval()``, ``generate()`` and
            ``decoder.config.max_position_embeddings``.
        image: Image handed to ``load_and_preprocess_image``.
        processor: Processor providing ``tokenizer``, ``batch_decode`` and
            ``token2json``.
        device: Device the input tensors are moved to before generation.

    Returns:
        The result of ``processor.token2json`` applied to the first decoded
        sequence after the EOS/PAD tokens and the leading task tag are
        stripped.
    """
    tokenizer = processor.tokenizer

    # Load and preprocess the image, then move it next to the model.
    pixel_values = load_and_preprocess_image(image, processor).to(device)

    # Generate output using the model; inference only, so no gradients.
    model.eval()
    with torch.no_grad():
        task_prompt = "<s_receipt>"  # <s_cord-v2> for v1
        decoder_input_ids = tokenizer(
            task_prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids.to(device)
        generated_outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            early_stopping=True,
            # Forbid the unknown token from appearing in the output.
            bad_words_ids=[[tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

    # Decode generated output (only the first sequence is used).
    decoded_text = processor.batch_decode(generated_outputs.sequences)[0]
    decoded_text = decoded_text.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
    # count=1: remove only the first tag, i.e. the task start token.
    decoded_text = re.sub(r"<.*?>", "", decoded_text, count=1).strip()
    return processor.token2json(decoded_text)
45
+
46
+
47
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# The processor and the model weights are loaded from the same checkpoint id.
processor = DonutProcessor.from_pretrained("AdamCodd/donut-receipts-extract")
model = VisionEncoderDecoderModel.from_pretrained("AdamCodd/donut-receipts-extract")
model.to(device)
51
+
52
+
53
def process_image(image):
    """Interface callback: extract structured data from an uploaded image.

    Args:
        image: The value supplied by the image input component, or ``None``
            when the form is submitted without an image.

    Returns:
        The result of ``generate_text_from_image`` for ``image`` using the
        module-level ``model``, ``processor`` and ``device``; ``None`` when
        no image was provided.
    """
    # Nothing to run the model on; return an empty result instead of
    # failing inside the processor.
    if image is None:
        return None
    return generate_text_from_image(model, image, processor, device)
56
+
57
+
58
# Deliver the upload to the callback as a PIL image.
image = gr.Image(type='pil')
# process_image returns the output of processor.token2json, a parsed
# structure rather than a class label or label->confidence mapping, so it
# is rendered with a JSON component. The variable name is kept for
# compatibility with any code that refers to it.
label = gr.JSON()

intf = gr.Interface(fn=process_image, inputs=image, outputs=label)
intf.launch(inline=False)