# document-vqa-v2 / main.py
import re
from io import BytesIO

import fitz  # PyMuPDF
import torch
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from starlette.middleware.cors import CORSMiddleware
from transformers import DonutProcessor, VisionEncoderDecoderModel, pipeline
app = FastAPI()

# Donut model for visual question answering on document images
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Extractive QA pipeline used by /pdfQA/. The original file calls `nlp_qa`
# without defining it, so the pipeline's default SQuAD-finetuned model is
# assumed here; swap in any question-answering checkpoint as needed.
nlp_qa = pipeline("question-answering")
@app.post("/donutQA/")
async def donut_question_answering(
file: UploadFile = File(...),
questions: str = Form(...),
):
try:
# Read the uploaded file as bytes
contents = await file.read()
# Open the image using PIL
image = Image.open(BytesIO(contents))
# Split the questions into a list
question_list = questions.split(',')
# Process document with Donut model for each question
answers = process_document(image, question_list)
# Return a dictionary with questions and corresponding answers
result_dict = dict(zip(question_list, answers))
return result_dict
except Exception as e:
return {"error": f"Error processing file: {str(e)}"}
def process_document(image, questions):
    # Prepare encoder inputs once; the same pixel values are reused for every question
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

    # Donut DocVQA prompt template; {user_input} is replaced with each question
    task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"

    answers = []
    for question in questions:
        prompt = task_prompt.replace("{user_input}", question)
        decoder_input_ids = processor.tokenizer(
            prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids

        # Generate the answer sequence
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids.to(device),
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        # Postprocess: drop EOS/pad tokens and the leading task start token,
        # then parse the tagged output into a JSON dict
        sequence = processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
            processor.tokenizer.pad_token, ""
        )
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
        answers.append(processor.token2json(sequence))

    return answers
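
# For orientation, one Donut DocVQA round trip looks roughly like this
# (the question and answer values below are hypothetical):
#   prompt:   "<s_docvqa><s_question>What is the total?</s_question><s_answer>"
#   decoded:  "<s_docvqa><s_question>What is the total?</s_question><s_answer> $12.34</s_answer>"
#   token2json -> {"question": "What is the total?", "answer": "$12.34"}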
@app.post("/pdfQA/", description=description)
async def pdf_question_answering(
file: UploadFile = File(...),
questions: str = Form(...),
):
try:
# Read the uploaded file as bytes
contents = await file.read()
# Initialize an empty string to store the text content of the PDF
all_text = ""
# Use PyMuPDF to process the PDF and extract text
pdf_document = fitz.open_from_bytes(contents)
# Loop through each page and perform OCR
for page_num in range(pdf_document.page_count):
page = pdf_document.load_page(page_num)
print(f"Processing page {page_num + 1}...")
text = page.get_text()
all_text += text + '\n'
# Print or do something with the collected text
print(all_text)
# List of questions
question_list = questions.split(',')
# Initialize an empty dictionary to store questions and answers
qa_dict = {}
# Get answers for each question with the same context
for question in question_list:
result = nlp_qa({
'question': question,
'context': all_text
})
# Access the 'answer' key from the result
answer = result['answer']
# Store the question and answer in the dictionary
qa_dict[question] = answer
return qa_dict
except Exception as e:
return JSONResponse(content=f"Error processing PDF file: {str(e)}", status_code=500)
# Set up CORS middleware; "*" allows any origin, or list specific origins instead
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
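
# Example client call, assuming the app is served locally with
# `uvicorn main:app --port 8000` (the file name "invoice.png" and the
# questions below are placeholders):
#
#   import requests
#   with open("invoice.png", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/donutQA/",
#           files={"file": f},
#           data={"questions": "What is the invoice number?,What is the total?"},
#       )
#   print(resp.json())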