Iqra Ali committed
Commit 25c5ef7 · 1 Parent(s): 2f61d8a

Update app.py

Files changed (1):
  app.py (+13 -12)
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
 import torch
 from PIL import Image
 
-#from donut import DonutModel
+from donut import DonutModel
 
 def demo_process(input_img):
     global pretrained_model, task_prompt, task_name
@@ -12,25 +12,26 @@ def demo_process(input_img):
 
 task_prompt = f"<s_cord-v2>"
 
-image = Image.open("/content/SKMBT_75122072616550_Page_37_Image_0001.png")
+image = Image.open("./sample_image_cord_test_receipt_00004.png")
 image.save("cord_sample_receipt1.png")
-image = Image.open("/content/SKMBT_75122072616550_Page_50_Image_0001.png")
+image = Image.open("./sample_image_cord_test_receipt_00012.png")
 image.save("cord_sample_receipt2.png")
 
-#pretrained_model = DonutModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
-#pretrained_model.encoder.to(torch.bfloat16)
-
-model = torch.load("/content/drive/MyDrive/fast_job/DONUT_model/donut/model.pt")
-# Move model to GPU
-device = "cuda" if torch.cuda.is_available() else "cpu"
-model.to(device)
+pretrained_model = DonutModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
+pretrained_model.encoder.to(torch.bfloat16)
+pretrained_model.eval()
 
 demo = gr.Interface(
     fn=demo_process,
     inputs= gr.inputs.Image(type="pil"),
     outputs="json",
-    title=f"Donut 🍩 demonstration for `Medical Prescription Dataset` task",
-    description="""This model is trained with 200 medical prescription handwritten document images. <br>""",
+    title=f"Donut 🍩 demonstration for `cord-v2` task",
+    description="""This model is trained with 800 Indonesian receipt images of CORD dataset. <br>
+Demonstrations for other types of documents/tasks are available at https://github.com/clovaai/donut <br>
+More CORD receipt images are available at https://huggingface.co/datasets/naver-clova-ix/cord-v2
+More details are available at:
+- Paper: https://arxiv.org/abs/2111.15664
+- GitHub: https://github.com/clovaai/donut""",
     examples=[["cord_sample_receipt1.png"], ["cord_sample_receipt2.png"]],
     cache_examples=False,
 )
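
Note: the body of demo_process is unchanged by this commit, so it does not appear in the hunks above. A minimal sketch of what it presumably does, assuming donut-python's DonutModel.inference(image=..., prompt=...) API; this sketch is not part of the diff:

def demo_process(input_img):
    global pretrained_model, task_prompt, task_name
    # Run Donut inference on the uploaded PIL image with the CORD-v2 task prompt
    # and return the first parsed prediction, which Gradio renders as JSON.
    output = pretrained_model.inference(image=input_img, prompt=task_prompt)
    return output["predictions"][0]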