Spaces:
Running
Running
File size: 4,759 Bytes
8700a34 836458e 6bbd3ca c39e604 af17670 86a0b7a 6bbd3ca 86a0b7a 836458e 86a0b7a 6bbd3ca 86a0b7a 6bbd3ca 86a0b7a 6bbd3ca 86a0b7a 6bbd3ca a82199b 6bbd3ca c39e604 86a0b7a 41d335c 86a0b7a 41d335c 86a0b7a f198fb3 f8ec4b3 86a0b7a f8ec4b3 ed563a5 f8ec4b3 2181fee f8ec4b3 ed563a5 8700a34 ed563a5 c39e604 ed563a5 41d335c ed563a5 41d335c ed563a5 f198fb3 ed563a5 6bbd3ca 8700a34 f66f82d 0c14995 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import fitz
from fastapi import FastAPI, File, UploadFile, Form, Request, Response
from fastapi.responses import JSONResponse
from transformers import pipeline
from PIL import Image
from io import BytesIO
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
import torch
import re
from transformers import DonutProcessor, VisionEncoderDecoderModel
app = FastAPI()

# Donut doc-VQA model: OCR-free visual question answering on document images,
# used by the /donutQA/ endpoint.
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")

# Extractive text-QA pipeline used by the /pdfQA/ endpoint. The original file
# called `nlp_qa` without ever defining it (NameError on every /pdfQA/ request);
# `pipeline` was imported but unused. Instantiate it once at startup here.
nlp_qa = pipeline("question-answering")

# Run the Donut model on GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
@app.post("/donutQA/")
async def donut_question_answering(
    file: UploadFile = File(...),
    questions: str = Form(...),
):
    """Answer comma-separated questions about an uploaded document image.

    Reads the uploaded image, runs the Donut doc-VQA model once per question,
    and returns a mapping of question -> answer. On failure, returns a JSON
    error payload with HTTP 500 (consistent with /pdfQA/).
    """
    try:
        # Read the uploaded file as bytes
        contents = await file.read()
        # Donut's processor expects a 3-channel image; uploads may be
        # RGBA / grayscale / palette, so normalize to RGB explicitly.
        image = Image.open(BytesIO(contents)).convert("RGB")
        # Split on commas and trim stray whitespace so "q1, q2" and "q1,q2"
        # produce the same question keys.
        question_list = [q.strip() for q in questions.split(',')]
        # Process document with Donut model for each question
        answers = process_document(image, question_list)
        # Return a dictionary with questions and corresponding answers
        return dict(zip(question_list, answers))
    except Exception as e:
        # Surface processing errors as HTTP 500 rather than a 200 with an
        # error body, matching the /pdfQA/ endpoint's error handling.
        return JSONResponse(
            content={"error": f"Error processing file: {str(e)}"},
            status_code=500,
        )
def process_document(image, questions):
    """Run Donut doc-VQA over `image` for each question; return a list of answers.

    Each answer is the JSON produced by `processor.token2json` for the decoded
    model output. The image is encoded once and the pixel tensor reused across
    all questions.
    """
    # Encode the image a single time — this is the expensive encoder input.
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # Donut doc-VQA task prompt; {user_input} is substituted per question.
    prompt_template = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"

    answers = []
    for q in questions:
        filled_prompt = prompt_template.replace("{user_input}", q)
        decoder_ids = processor.tokenizer(
            filled_prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids

        # Greedy generation (num_beams=1) up to the decoder's positional limit;
        # the <unk> token is forbidden via bad_words_ids.
        generated = model.generate(
            pixel_values.to(device),
            decoder_input_ids=decoder_ids.to(device),
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        # Decode, drop EOS/PAD markers, then strip the leading task-start token.
        raw = processor.batch_decode(generated.sequences)[0]
        raw = raw.replace(processor.tokenizer.eos_token, "")
        raw = raw.replace(processor.tokenizer.pad_token, "")
        raw = re.sub(r"<.*?>", "", raw, count=1).strip()  # remove first task start token

        answers.append(processor.token2json(raw))

    return answers
@app.post(
    "/pdfQA/",
    # The original passed `description=description`, but no `description`
    # variable exists anywhere in the file — a NameError at import time that
    # prevented the app from starting. Use a literal description instead.
    description="Extract text from an uploaded PDF and answer comma-separated questions about it.",
)
async def pdf_question_answering(
    file: UploadFile = File(...),
    questions: str = Form(...),
):
    """Answer comma-separated questions against the text of an uploaded PDF.

    Extracts the text layer of every page with PyMuPDF, then runs the
    extractive QA pipeline once per question over the combined text.
    Returns a mapping of question -> answer, or HTTP 500 on failure.
    """
    try:
        # Read the uploaded file as bytes
        contents = await file.read()
        # Collected text content of the whole PDF
        all_text = ""
        # PyMuPDF has no `open_from_bytes`; opening an in-memory document is
        # done via the `stream=` keyword. `with` ensures the document handle
        # is closed even if a page fails to load.
        with fitz.open(stream=contents, filetype="pdf") as pdf_document:
            for page_num in range(pdf_document.page_count):
                page = pdf_document.load_page(page_num)
                # NOTE: get_text() reads the embedded text layer — scanned
                # image-only PDFs will yield empty text (no OCR is performed).
                all_text += page.get_text() + '\n'
        # Split the questions into a list
        question_list = questions.split(',')
        # Get an answer for each question against the same context
        qa_dict = {}
        for question in question_list:
            result = nlp_qa({
                'question': question,
                'context': all_text
            })
            qa_dict[question] = result['answer']
        return qa_dict
    except Exception as e:
        return JSONResponse(content=f"Error processing PDF file: {str(e)}", status_code=500)
# Set up CORS middleware so browser clients on other origins can call the API.
# NOTE(review): `allow_origins=["*"]` combined with `allow_credentials=True` is
# a permissive configuration — browsers/Starlette restrict wildcard origins
# when credentials are allowed; confirm and pin explicit origins for production.
origins = ["*"] # or specify your list of allowed origins
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
|