Geewook Kim commited on
Commit
ebf0b03
·
1 Parent(s): 582d0f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -7
app.py CHANGED
@@ -25,14 +25,7 @@ def demo_process_vqa(input_img, question):
25
  def demo_process(input_img):
26
  global pretrained_model, task_prompt, task_name
27
  input_img = Image.fromarray(input_img)
28
-
29
- pretrained_model = DonutModel.from_pretrained(args.pretrained_path, max_length=512)
30
- pretrained_model.encoder.to(torch.bfloat16)
31
- pretrained_model.eval()
32
-
33
  output = pretrained_model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
34
-
35
- del pretrained_model
36
  return output
37
 
38
 
@@ -55,6 +48,10 @@ if __name__ == "__main__":
55
  if args.sample_img_path:
56
  example_sample.append(args.sample_img_path)
57
 
 
 
 
 
58
  demo = gr.Interface(
59
  fn=demo_process_vqa if task_name == "docvqa" else demo_process,
60
  inputs=["image", "text"] if task_name == "docvqa" else "image",
 
25
  def demo_process(input_img):
26
  global pretrained_model, task_prompt, task_name
27
  input_img = Image.fromarray(input_img)
 
 
 
 
 
28
  output = pretrained_model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
 
 
29
  return output
30
 
31
 
 
48
  if args.sample_img_path:
49
  example_sample.append(args.sample_img_path)
50
 
51
+ pretrained_model = DonutModel.from_pretrained(args.pretrained_path, max_length=128)
52
+ pretrained_model.encoder.to(torch.bfloat16)
53
+ pretrained_model.eval()
54
+
55
  demo = gr.Interface(
56
  fn=demo_process_vqa if task_name == "docvqa" else demo_process,
57
  inputs=["image", "text"] if task_name == "docvqa" else "image",