Geewook Kim
committed on
Commit
·
ebf0b03
1
Parent(s):
582d0f3
Update app.py
Browse files
app.py
CHANGED
@@ -25,14 +25,7 @@ def demo_process_vqa(input_img, question):
|
|
25 |
def demo_process(input_img):
|
26 |
global pretrained_model, task_prompt, task_name
|
27 |
input_img = Image.fromarray(input_img)
|
28 |
-
|
29 |
-
pretrained_model = DonutModel.from_pretrained(args.pretrained_path, max_length=512)
|
30 |
-
pretrained_model.encoder.to(torch.bfloat16)
|
31 |
-
pretrained_model.eval()
|
32 |
-
|
33 |
output = pretrained_model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
|
34 |
-
|
35 |
-
del pretrained_model
|
36 |
return output
|
37 |
|
38 |
|
@@ -55,6 +48,10 @@ if __name__ == "__main__":
|
|
55 |
if args.sample_img_path:
|
56 |
example_sample.append(args.sample_img_path)
|
57 |
|
|
|
|
|
|
|
|
|
58 |
demo = gr.Interface(
|
59 |
fn=demo_process_vqa if task_name == "docvqa" else demo_process,
|
60 |
inputs=["image", "text"] if task_name == "docvqa" else "image",
|
|
|
25 |
def demo_process(input_img):
|
26 |
global pretrained_model, task_prompt, task_name
|
27 |
input_img = Image.fromarray(input_img)
|
|
|
|
|
|
|
|
|
|
|
28 |
output = pretrained_model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
|
|
|
|
|
29 |
return output
|
30 |
|
31 |
|
|
|
48 |
if args.sample_img_path:
|
49 |
example_sample.append(args.sample_img_path)
|
50 |
|
51 |
+
pretrained_model = DonutModel.from_pretrained(args.pretrained_path, max_length=128)
|
52 |
+
pretrained_model.encoder.to(torch.bfloat16)
|
53 |
+
pretrained_model.eval()
|
54 |
+
|
55 |
demo = gr.Interface(
|
56 |
fn=demo_process_vqa if task_name == "docvqa" else demo_process,
|
57 |
inputs=["image", "text"] if task_name == "docvqa" else "image",
|