"""FastAPI service exposing two document question-answering endpoints.

* ``POST /donutQA/`` answers questions about an uploaded image using the
  Donut DocVQA checkpoint loaded below.
* ``POST /pdfQA/`` extracts the embedded text of an uploaded PDF with PyMuPDF
  and answers questions against that text with an extractive QA pipeline.

Both endpoints take a multipart form with a ``file`` upload and a
``questions`` field holding comma-separated questions.
"""
import logging
import re
from io import BytesIO
from typing import List

import fitz
import torch
from fastapi import FastAPI, File, Form, Request, Response, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from transformers import DonutProcessor, VisionEncoderDecoderModel, pipeline

logger = logging.getLogger(__name__)

app = FastAPI()

# Donut DocVQA model; loaded once at import time and shared by all requests.
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Extractive QA pipeline used by /pdfQA/.
# NOTE(review): ``nlp_qa`` was referenced but never defined in this file; the
# library's default question-answering checkpoint is assumed here -- confirm
# (and pin) the intended model.
nlp_qa = pipeline("question-answering")

# OpenAPI description for /pdfQA/ (the name was referenced by the route
# decorator but never defined, which raised NameError at import time).
description = (
    "Upload a PDF and a comma-separated list of questions. The text of every "
    "page is extracted and each question is answered against that text. "
    "Returns a JSON object mapping each question to its answer."
)

# Prompt template the DocVQA checkpoint was fine-tuned with: the question is
# wrapped in task tokens and the model continues after ``<s_answer>``.
_DOCVQA_PROMPT = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"


def _parse_questions(questions: str) -> List[str]:
    """Split a comma-separated string into stripped, non-empty questions."""
    return [part.strip() for part in questions.split(",") if part.strip()]


@app.post("/donutQA/")
async def donut_question_answering(
    file: UploadFile = File(...),
    questions: str = Form(...),
):
    """Answer questions about an uploaded document image with Donut.

    Args:
        file: Uploaded image in any format PIL can open.
        questions: Comma-separated questions about the image.

    Returns:
        A dict mapping each question to the structure produced by
        ``processor.token2json`` for that question. On failure, a JSON
        response ``{"error": "..."}`` with HTTP status 500.
    """
    try:
        contents = await file.read()
        # Normalise to RGB so palette, grayscale, or alpha-channel uploads are
        # handed to the processor in a consistent 3-channel form.
        image = Image.open(BytesIO(contents)).convert("RGB")

        question_list = _parse_questions(questions)
        answers = process_document(image, question_list)

        # Duplicate questions collapse onto a single key.
        return dict(zip(question_list, answers))
    except Exception as e:  # top-level request boundary: log and report
        logger.exception("donutQA request failed")
        return JSONResponse(
            content={"error": f"Error processing file: {str(e)}"},
            status_code=500,
        )


def process_document(image, questions):
    """Run the Donut DocVQA model on one image for every question.

    Args:
        image: PIL image of the document.
        questions: Iterable of question strings.

    Returns:
        A list with one ``processor.token2json`` result per question, in the
        same order as ``questions``.
    """
    tokenizer = processor.tokenizer

    # Encoder input depends only on the image: compute and move it once.
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

    answers = []
    for question in questions:
        prompt = _DOCVQA_PROMPT.replace("{user_input}", question)
        decoder_input_ids = tokenizer(
            prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids

        # Greedy decoding (num_beams=1); ``early_stopping`` only applies to
        # beam search, so it is not passed.
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids.to(device),
            max_length=model.decoder.config.max_position_embeddings,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        # Post-process: drop EOS/PAD markers, then the leading task start token.
        sequence = processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()

        answers.append(processor.token2json(sequence))

    return answers


@app.post("/pdfQA/", description=description)
async def pdf_question_answering(
    file: UploadFile = File(...),
    questions: str = Form(...),
):
    """Answer questions against the embedded text of an uploaded PDF.

    Args:
        file: Uploaded PDF document.
        questions: Comma-separated questions about the document.

    Returns:
        A dict mapping each question to the ``answer`` field returned by the
        QA pipeline. If the PDF yields no text, a JSON error string with HTTP
        status 422; on any other failure, a JSON error string with HTTP
        status 500.
    """
    try:
        contents = await file.read()

        # Open from memory; the context manager closes the document afterwards.
        with fitz.open(stream=contents, filetype="pdf") as pdf_document:
            page_texts = []
            for page_num, page in enumerate(pdf_document, start=1):
                logger.info("Processing page %d...", page_num)
                # get_text() reads the PDF's text layer; no OCR is performed,
                # so image-only (scanned) pages contribute an empty string.
                page_texts.append(page.get_text())

        # Each page's text is followed by a newline, as before.
        all_text = "".join(f"{text}\n" for text in page_texts)
        logger.debug("Extracted PDF text:\n%s", all_text)

        if not all_text.strip():
            # Nothing to answer from (e.g. an empty or scanned PDF).
            return JSONResponse(
                content="Error processing PDF file: no extractable text found",
                status_code=422,
            )

        # Every question is answered against the same full-document context.
        return {
            question: nlp_qa(question=question, context=all_text)["answer"]
            for question in _parse_questions(questions)
        }
    except Exception as e:  # top-level request boundary: log and report
        logger.exception("pdfQA request failed")
        return JSONResponse(
            content=f"Error processing PDF file: {str(e)}",
            status_code=500,
        )


# Set up CORS middleware.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; list the trusted origins explicitly before exposing this service.
origins = ["*"]  # or specify your list of allowed origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)