File size: 4,759 Bytes
8700a34
836458e
6bbd3ca
 
c39e604
 
af17670
 
86a0b7a
 
6bbd3ca
86a0b7a
836458e
86a0b7a
6bbd3ca
86a0b7a
 
6bbd3ca
86a0b7a
 
6bbd3ca
86a0b7a
 
6bbd3ca
 
 
 
a82199b
6bbd3ca
 
c39e604
 
 
86a0b7a
 
41d335c
86a0b7a
 
41d335c
86a0b7a
 
 
f198fb3
f8ec4b3
86a0b7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8ec4b3
ed563a5
 
f8ec4b3
 
 
 
2181fee
f8ec4b3
 
ed563a5
 
 
 
 
 
 
8700a34
ed563a5
 
 
 
c39e604
ed563a5
 
41d335c
ed563a5
 
41d335c
ed563a5
 
f198fb3
ed563a5
 
 
 
 
 
 
 
 
 
 
 
 
 
6bbd3ca
 
8700a34
f66f82d
0c14995
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import fitz
from fastapi import FastAPI, File, UploadFile, Form, Request, Response
from fastapi.responses import JSONResponse
from transformers import pipeline
from PIL import Image
from io import BytesIO
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
import torch
import re

from transformers import DonutProcessor, VisionEncoderDecoderModel

app = FastAPI()

# Hugging Face Hub checkpoint for Donut; process_document drives it with the
# "<s_docvqa>" task prompt.
DONUT_CHECKPOINT = "naver-clova-ix/donut-base-finetuned-docvqa"

# Processor and model are loaded once at import time and shared by all requests.
processor = DonutProcessor.from_pretrained(DONUT_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(DONUT_CHECKPOINT)

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

@app.post("/donutQA/")
async def donut_question_answering(
    file: UploadFile = File(...),
    questions: str = Form(...),
):
    """Answer comma-separated questions about an uploaded document image.

    Args:
        file: Uploaded image file (any format PIL can open).
        questions: Comma-separated questions. Surrounding whitespace is
            stripped and empty entries (e.g. from a trailing comma) are dropped.

    Returns:
        A mapping of each question to the Donut result produced by
        ``process_document``. On any failure, a 500 JSON response of the form
        ``{"error": "Error processing file: ..."}``.
    """
    try:
        contents = await file.read()

        # Normalize to 3-channel RGB: grayscale ("L") or alpha ("RGBA"/"P")
        # images may not be accepted by the Donut image processor. convert()
        # also forces decoding here, so corrupt files fail inside the try.
        image = Image.open(BytesIO(contents)).convert("RGB")

        question_list = [q.strip() for q in questions.split(",") if q.strip()]

        answers = process_document(image, question_list)
        return dict(zip(question_list, answers))

    except Exception as e:
        # Report failures with a server-error status, consistent with /pdfQA/,
        # instead of a 200 that clients would treat as success.
        return JSONResponse(
            content={"error": f"Error processing file: {e}"},
            status_code=500,
        )

def process_document(image, questions):
    """Answer each question about ``image`` with the Donut DocVQA model.

    The image is encoded once and reused for every question. Each question is
    decoded separately with greedy generation.

    Args:
        image: Document image passed directly to ``processor`` (the /donutQA/
            endpoint supplies a PIL image).
        questions: Iterable of question strings, answered in order.

    Returns:
        A list with one entry per question, each being whatever
        ``processor.token2json`` parses from the generated sequence
        (presumably a dict with the question and answer; TODO confirm keys).
    """
    # Encoder input: compute it and move it to the model's device once. It does
    # not depend on the question, so there is no need to re-copy it per loop.
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

    tokenizer = processor.tokenizer

    # Decoder prompt template: the model continues generating after <s_answer>.
    task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"

    answers = []
    for question in questions:
        prompt = task_prompt.replace("{user_input}", question)
        decoder_input_ids = tokenizer(
            prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids

        # Greedy decoding (num_beams=1). early_stopping is omitted because it
        # only affects beam search; passing it here just triggers warnings.
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids.to(device),
            max_length=model.decoder.config.max_position_embeddings,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[tokenizer.unk_token_id]],  # never emit <unk>
            return_dict_in_generate=True,
        )

        # Strip EOS/padding, then drop the first tag (the task start token)
        # before parsing the tagged output into structured form.
        sequence = processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()

        answers.append(processor.token2json(sequence))

    return answers

# Extractive question-answering pipeline for /pdfQA/. It is created lazily on
# first use, so importing the module does not load a second model.
_nlp_qa = None


def _get_nlp_qa():
    """Return the shared question-answering pipeline, creating it on first call."""
    global _nlp_qa
    if _nlp_qa is None:
        # NOTE(review): pinned to transformers' default question-answering
        # checkpoint; swap in the model the deployment actually intends.
        _nlp_qa = pipeline(
            "question-answering",
            model="distilbert-base-cased-distilled-squad",
            device=device,
        )
    return _nlp_qa


def _extract_pdf_text(contents):
    """Return the embedded text layer of every page, each followed by a newline."""
    # get_text() reads the PDF's text layer; it does not OCR, so image-only
    # (scanned) pages contribute no text.
    with fitz.open(stream=contents, filetype="pdf") as pdf_document:
        return "".join(page.get_text() + "\n" for page in pdf_document)


@app.post(
    "/pdfQA/",
    description="Answer comma-separated questions using the text of an uploaded PDF.",
)
async def pdf_question_answering(
    file: UploadFile = File(...),
    questions: str = Form(...),
):
    """Answer comma-separated questions from the text extracted from a PDF.

    Args:
        file: Uploaded PDF file.
        questions: Comma-separated questions. Surrounding whitespace is
            stripped and empty entries are dropped.

    Returns:
        A mapping of each question to the answer span extracted from the PDF
        text. Returns a 422 response if the PDF has no extractable text, or a
        500 response with an error string on any other failure.
    """
    try:
        contents = await file.read()
        all_text = _extract_pdf_text(contents)

        if not all_text.strip():
            return JSONResponse(
                content="Error processing PDF file: no extractable text (scanned PDFs need OCR)",
                status_code=422,
            )

        question_list = [q.strip() for q in questions.split(",") if q.strip()]

        # Every question is answered against the same full-document context.
        nlp_qa = _get_nlp_qa()
        return {
            question: nlp_qa(question=question, context=all_text)["answer"]
            for question in question_list
        }

    except Exception as e:
        return JSONResponse(content=f"Error processing PDF file: {e}", status_code=500)

# Set up CORS middleware.
# Any origin may call the API, but credentialed (cookie/auth) requests are not
# allowed. Starlette documents that allow_origins/methods/headers cannot be "*"
# when allow_credentials is enabled. The endpoints here use no cookies or auth.
# To allow credentials, list explicit trusted origins instead of "*".
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)