import fitz  # PyMuPDF
import re
from io import BytesIO

import torch
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from starlette.middleware.cors import CORSMiddleware
from transformers import DonutProcessor, VisionEncoderDecoderModel, pipeline
app = FastAPI()

# Donut model fine-tuned for document visual question answering
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Extractive QA pipeline used by the PDF endpoint below. The original snippet
# referenced nlp_qa without defining it; no checkpoint was specified, so this
# loads the pipeline's default question-answering model (swap in your own).
nlp_qa = pipeline("question-answering")
# Route decorator added so the endpoint is registered; the path is assumed.
@app.post("/donut_question_answering/")
async def donut_question_answering(
    file: UploadFile = File(...),
    questions: str = Form(...),
):
    try:
        # Read the uploaded file as bytes
        contents = await file.read()

        # Open the image with PIL and force RGB so the processor gets 3 channels
        image = Image.open(BytesIO(contents)).convert("RGB")

        # Split the comma-separated questions into a list
        question_list = [q.strip() for q in questions.split(',')]

        # Run the Donut model once per question
        answers = process_document(image, question_list)

        # Return a dictionary mapping each question to its answer
        return dict(zip(question_list, answers))
    except Exception as e:
        return {"error": f"Error processing file: {str(e)}"}
def process_document(image, questions):
    # Prepare encoder inputs (pixel values for the document image)
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # Donut DocVQA prompt template; {user_input} is replaced with each question
    task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"

    # Collect one answer per question
    answers = []
    for question in questions:
        prompt = task_prompt.replace("{user_input}", question)
        decoder_input_ids = processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
        # Generate the answer autoregressively
        outputs = model.generate(
            pixel_values.to(device),
            decoder_input_ids=decoder_input_ids.to(device),
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )
        # Postprocess: strip special tokens, drop the task start token, parse to JSON
        sequence = processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
        answers.append(processor.token2json(sequence))
    return answers
# Route decorator added here as well; the path is assumed.
@app.post("/pdf_question_answering/")
async def pdf_question_answering(
    file: UploadFile = File(...),
    questions: str = Form(...),
):
    try:
        # Read the uploaded file as bytes
        contents = await file.read()

        # Accumulate the text content of the PDF
        all_text = ""

        # Open the PDF from the in-memory bytes with PyMuPDF
        pdf_document = fitz.open(stream=contents, filetype="pdf")

        # Extract the embedded text of each page (no OCR is performed here)
        for page_num in range(pdf_document.page_count):
            page = pdf_document.load_page(page_num)
            print(f"Processing page {page_num + 1}...")
            all_text += page.get_text() + '\n'

        # Debug: dump the collected text
        print(all_text)

        # Split the comma-separated questions into a list
        question_list = [q.strip() for q in questions.split(',')]
        # Map each question to its extracted answer
        qa_dict = {}
        for question in question_list:
            # Run extractive QA over the full PDF text as context
            result = nlp_qa({
                'question': question,
                'context': all_text
            })
            qa_dict[question] = result['answer']
        return qa_dict
    except Exception as e:
        return JSONResponse(content=f"Error processing PDF file: {str(e)}", status_code=500)
# Set up CORS middleware
origins = ["*"]  # or specify your list of allowed origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
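To try the service locally, a minimal client sketch follows. It assumes the code above is saved as app.py and served with uvicorn (uvicorn app:app --reload), that the assumed route paths from the decorators above are used, and that a hypothetical sample_invoice.png exists; adjust the file name, port, and paths to your setup.

# Client-side sketch; server URL, route path, and file name are assumptions.
import requests

with open("sample_invoice.png", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/donut_question_answering/",
        files={"file": ("sample_invoice.png", f, "image/png")},
        data={"questions": "What is the invoice number?,What is the total amount?"},
    )
print(resp.json())  # e.g. a dict mapping each question to its Donut answer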