uartimcs commited on
Commit
67c0efb
·
verified ·
1 Parent(s): 0b6738d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -26
app.py CHANGED
@@ -1,26 +1,29 @@
1
- import gradio as gr
2
- import argparse
3
- import torch
4
- from PIL import Image
5
- from donut import DonutModel
6
- def demo_process(input_img):
7
- global model, task_prompt, task_name
8
- input_img = Image.fromarray(input_img)
9
- output = model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
10
- return output
11
- parser = argparse.ArgumentParser()
12
- parser.add_argument("--task", type=str, default="Booking")
13
- parser.add_argument("--pretrained_path", type=str, default="result/train_booking/20241112_150925")
14
- args, left_argv = parser.parse_known_args()
15
- task_name = args.task
16
- task_prompt = f"<s_{task_name}>"
17
- model = DonutModel.from_pretrained("./result/train_booking/20241112_150925")
18
- if torch.cuda.is_available():
19
- model.half()
20
- device = torch.device("cuda")
21
- model.to(device)
22
- else:
23
- model.encoder.to(torch.bfloat16)
24
- model.eval()
25
- demo = gr.Interface(fn=demo_process,inputs="image",outputs="json", title=f"Donut 🍩 demonstration for `{task_name}` task",)
26
- demo.launch(debug=True)
 
 
 
 
1
+ import gradio as gr
2
+ import argparse
3
+ import torch
4
+ from PIL import Image
5
+ from donut import DonutModel
6
+ def demo_process(input_img):
7
+ global model, task_prompt, task_name
8
+ input_img = Image.fromarray(input_img)
9
+ output = model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
10
+ return output
11
+ parser = argparse.ArgumentParser()
12
+ parser.add_argument("--task", type=str, default="Booking")
13
+ parser.add_argument("--pretrained_path", type=str, default="uartimcs/donut-booking-extract")
14
+ args, left_argv = parser.parse_known_args()
15
+ task_name = args.task
16
+ task_prompt = f"<s_{task_name}>"
17
+
18
+ image = Image.open("./sample-booking/CMA_150.jpg")
19
+ image.save("CMA_sample.jpg")
20
+ image = Image.open("./sample-booking/COSCO_150.jpg")
21
+ image.save("COSCO_sample.jpg")
22
+ image = Image.open("./sample-booking/ONEY_150.jpg")
23
+ image.save("ONEY_sample.jpg")
24
+
25
+
26
+ model = DonutModel.from_pretrained("uartimcs/donut-booking-extract")
27
+ model.eval()
28
+ demo = gr.Interface(fn=demo_process,inputs="image",outputs="json", title=f"Donut 🍩 demonstration for `{task_name}` task", examples=[["CMA_sample.jpg"], ["COSCO_sample.jpg"], ["ONEY_sample.jpg"]],)
29
+ demo.launch()