Spaces:
Runtime error
Runtime error
File size: 6,398 Bytes
d49f9c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
from fastapi import APIRouter, File, UploadFile, Form, HTTPException, status
from fastapi.responses import JSONResponse
from config import settings
from PIL import Image
import urllib.request
from io import BytesIO
import utils
import os
import time
from functools import lru_cache
from paddleocr import PaddleOCR
from pdf2image import convert_from_bytes
import io
import json
from routers.data_utils import merge_data
from routers.data_utils import store_data
import motor.motor_asyncio
from typing import Optional
from pymongo import ASCENDING
from pymongo.errors import DuplicateKeyError
# Router that collects the OCR endpoints and lifecycle hooks defined in this module.
router = APIRouter()
# MongoDB client; remains None unless startup_event runs with MONGODB_URL set.
client = None
# Handle to the client's `chatgpt_plugin` database; assigned together with `client`.
db = None
async def create_unique_index(collection, *fields):
    """Create a unique, ascending index over ``fields`` on ``collection``.

    Args:
        collection: Object exposing an awaitable ``create_index(keys, **options)``.
        *fields: Field names that together form the unique key, in order.

    Returns:
        Whatever ``collection.create_index`` returns.
    """
    key_spec = []
    for name in fields:
        # 1 is the ascending sort direction in an index key specification.
        key_spec.append((name, 1))
    created = await collection.create_index(key_spec, unique=True)
    return created
async def create_ttl_index(db, collection_name, field, expire_after_seconds):
    """Create an expiring (TTL) index on one field of a collection.

    Args:
        db: Database handle supporting ``db[collection_name]`` lookup.
        collection_name: Name of the collection to index.
        field: Field the ascending index is built on.
        expire_after_seconds: Value passed as the index's ``expireAfterSeconds`` option.

    Returns:
        None. The result of ``create_index`` is only printed.
    """
    index_name = await db[collection_name].create_index(
        [(field, ASCENDING)],
        expireAfterSeconds=expire_after_seconds,
    )
    print(f"TTL index created or already exists: {index_name}")
@router.on_event("startup")
async def startup_event():
    """Connect to MongoDB and ensure the required indexes exist.

    Does nothing unless the ``MONGODB_URL`` environment variable is present.
    When it is, the module-level ``client`` and ``db`` are assigned, unique
    indexes are created on ``uploads`` and ``receipts``, and a 15-minute TTL
    index is created on ``uploads.created_at``.
    """
    global client, db

    if "MONGODB_URL" not in os.environ:
        return

    client = motor.motor_asyncio.AsyncIOMotorClient(os.environ["MONGODB_URL"])
    db = client.chatgpt_plugin

    # (collection name, fields forming the unique key), created in this order.
    unique_indexes = (
        ('uploads', ('receipt_key',)),
        ('receipts', ('user', 'receipt_key')),
    )
    for collection_name, key_fields in unique_indexes:
        index_result = await create_unique_index(db[collection_name], *key_fields)
        print(f"Unique index created or already exists: {index_result}")

    await create_ttl_index(db, 'uploads', 'created_at', 15*60)

    print("Connected to MongoDB from OCR!")
@router.on_event("shutdown")
async def shutdown_event():
    """Close the MongoDB client when ``MONGODB_URL`` is present in the environment."""
    if "MONGODB_URL" not in os.environ:
        return
    # `client` is only read here, so no `global` declaration is needed.
    client.close()
@lru_cache(maxsize=1)
def load_ocr_model():
    """Build the English PaddleOCR model with angle classification enabled.

    The result is cached by ``lru_cache``, so repeated calls in the same
    process return the same model instance.
    """
    return PaddleOCR(use_angle_cls=True, lang='en')
def invoke_ocr(doc, content_type):
    """Run OCR on a single image and return the merged results with timing.

    Args:
        doc: PIL image (must support ``save(fp, format=...)``).
        content_type: MIME type of the original input. ``"image/png"`` is
            encoded as PNG before OCR; every other value is encoded as JPEG.

    Returns:
        A ``(values, processing_time)`` tuple: the output of ``merge_data``
        over all recognised lines, and the elapsed time in seconds.
    """
    worker_pid = os.getpid()
    print(f"Handling OCR request with worker PID: {worker_pid}")
    start_time = time.perf_counter()

    model = load_ocr_model()

    format_img = "PNG" if content_type == "image/png" else "JPEG"
    # The JPEG encoder rejects modes such as RGBA or P (e.g. a PNG uploaded with
    # a JPEG content type); convert those instead of failing inside save().
    if format_img == "JPEG" and getattr(doc, "mode", "RGB") not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        doc = doc.convert("RGB")

    with io.BytesIO() as bytes_img:
        doc.save(bytes_img, format=format_img)
        bytes_data = bytes_img.getvalue()

    result = model.ocr(bytes_data, cls=True)

    # Flatten the per-page results. A page with no detected text can come back
    # as None (or empty), which must be skipped rather than iterated.
    values = [line for page in (result or []) if page for line in page]
    values = merge_data(values)

    processing_time = time.perf_counter() - start_time
    print(f"OCR done, worker PID: {worker_pid}")

    return values, processing_time
# MIME types decoded directly as raster images.
_IMAGE_CONTENT_TYPES = ("image/jpeg", "image/jpg", "image/png")
# MIME types treated as PDF for direct uploads.
_UPLOAD_PDF_CONTENT_TYPES = ("application/pdf",)
# MIME types treated as PDF for downloaded URLs. The generic binary type is
# kept because some hosts serve PDFs that way (see the test PDF URL below).
_URL_PDF_CONTENT_TYPES = ("application/pdf", "application/octet-stream")
_INVALID_FILE_TYPE_MSG = "Invalid file type. Only JPG/PNG images and PDF are allowed."


def _load_document(data, content_type):
    """Decode raw bytes into the PIL image that will be sent to OCR.

    Image content types are opened directly. Any other content type is treated
    as a PDF (callers validate the type first) and its first page is rendered
    at 300 DPI.
    """
    if content_type in _IMAGE_CONTENT_TYPES:
        return Image.open(BytesIO(data))
    # Only the first page is used, so do not render the remaining pages.
    pages = convert_from_bytes(data, 300, first_page=1, last_page=1)
    if not pages:
        raise HTTPException(status_code=400, detail="Failed to process the input.")
    return pages[0]


async def _ocr_and_store(doc, content_type, source_name, post_processing):
    """Run OCR, log timing statistics and optionally persist the result.

    Returns the OCR result, or the value returned by ``store_data`` when
    post-processing is requested and ``MONGODB_URL`` is configured.
    """
    result, processing_time = invoke_ocr(doc, content_type)

    utils.log_stats(settings.ocr_stats_file, [processing_time, source_name])
    print(f"Processing time OCR: {processing_time:.2f} seconds")

    if post_processing and "MONGODB_URL" in os.environ:
        print("Postprocessing...")
        try:
            result = await store_data(result, db)
        except DuplicateKeyError:
            # Must be raised, not returned: FastAPI only turns a *raised*
            # HTTPException into a 400 response.
            raise HTTPException(status_code=400, detail="Duplicate data.") from None
        print(f"Stored data with key: {result}")

    return result


@router.post("/ocr")
async def run_ocr(file: Optional[UploadFile] = File(None), image_url: Optional[str] = Form(None),
                  post_processing: Optional[bool] = Form(False), sparrow_key: str = Form(None)):
    """Run OCR on an uploaded file or on a document fetched from a URL.

    Args:
        file: Uploaded JPG/PNG image or PDF. Takes precedence over ``image_url``.
        image_url: http(s) URL of a JPG/PNG image or PDF, used when no file is given.
        post_processing: When true and ``MONGODB_URL`` is set, the OCR result is
            passed to ``store_data`` and the stored key is returned instead.
        sparrow_key: Must equal ``settings.sparrow_key``.

    Returns:
        A JSON response with the OCR result (or stored key), ``{"info": ...}``
        when no input was provided, or an ``{"error": ...}`` dict for an
        invalid key, unsupported file type or unsupported URL scheme.

    Raises:
        HTTPException: 400 on duplicate stored data or when processing fails.
    """
    from urllib.parse import urlparse

    if sparrow_key != settings.sparrow_key:
        return {"error": "Invalid Sparrow key."}

    result = None
    if file:
        content_type = file.content_type
        if content_type not in _IMAGE_CONTENT_TYPES + _UPLOAD_PDF_CONTENT_TYPES:
            return {"error": _INVALID_FILE_TYPE_MSG}

        doc = _load_document(await file.read(), content_type)
        result = await _ocr_and_store(doc, content_type, file.filename, post_processing)
    elif image_url:
        # test image url: https://raw.githubusercontent.com/katanaml/sparrow/main/sparrow-data/docs/input/invoices/processed/images/invoice_10.jpg
        # test PDF: https://raw.githubusercontent.com/katanaml/sparrow/main/sparrow-data/docs/input/receipts/2021/us/bestbuy-20211211_006.pdf

        # The URL comes from the request, and urlopen would otherwise also
        # accept file:// and other schemes, exposing local files.
        # NOTE(review): http(s) URLs can still target internal hosts; add an
        # allow-list if this endpoint is reachable by untrusted clients.
        if urlparse(image_url).scheme not in ("http", "https"):
            return {"error": "Invalid URL. Only http and https URLs are allowed."}

        with urllib.request.urlopen(image_url) as response:
            content_type = response.info().get_content_type()
            if content_type not in _IMAGE_CONTENT_TYPES + _URL_PDF_CONTENT_TYPES:
                return {"error": _INVALID_FILE_TYPE_MSG}
            doc = _load_document(response.read(), content_type)

        # parse file name from url
        file_name = image_url.split("/")[-1]
        result = await _ocr_and_store(doc, content_type, file_name, post_processing)
    else:
        result = {"info": "No input provided"}

    if result is None:
        raise HTTPException(status_code=400, detail="Failed to process the input.")

    return JSONResponse(status_code=status.HTTP_200_OK, content=result)
@router.get("/statistics")
async def get_statistics():
    """Return the contents of the OCR statistics file as parsed JSON.

    Returns an empty list when the file at ``settings.ocr_stats_file`` does
    not exist or does not contain valid JSON.
    """
    stats_path = settings.ocr_stats_file

    # No statistics recorded yet.
    if not os.path.exists(stats_path):
        return []

    with open(stats_path, 'r') as stats_file:
        try:
            return json.load(stats_file)
        except json.JSONDecodeError:
            # An unreadable file is reported as "no statistics".
            return []
|