# passport-recog / server.py
# Hugging Face Space file by ethanrom, commit 528f442 ("Update server.py"), 5.08 kB.
import contextlib
import json
import os
import tempfile
from typing import Any, Optional, Dict, List

from fastapi import FastAPI, File, UploadFile, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from paddleocr import PaddleOCR
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from passporteye import read_mrz
from pydantic.v1 import BaseModel as v1BaseModel
from pydantic.v1 import Field
from pydantic import BaseModel
from huggingface_hub import InferenceClient
from langchain.llms.base import LLM
# Hugging Face Inference API configuration.
# The API token comes from the "apiToken" environment variable (None if unset).
HF_token = os.environ.get("apiToken")
model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# Lower-case alias of the same token object, used when building the LLM client.
hf_token = HF_token
# Text-generation parameters handed to the LLM wrapper.
kwargs = dict(
    max_new_tokens=500,
    temperature=0.1,
    top_p=0.95,
    repetition_penalty=1.0,
    do_sample=True,
)
class KwArgsModel(v1BaseModel):
    # Pydantic-v1 mixin that carries extra text-generation keyword arguments.
    # `default_factory=dict` gives every instance its own fresh empty dict.
    kwargs: Dict[str, Any] = Field(default_factory=dict)
class CustomInferenceClient(LLM, KwArgsModel):
    """LangChain ``LLM`` backed by ``huggingface_hub.InferenceClient``.

    The generation parameters stored in ``kwargs`` (inherited from
    ``KwArgsModel``) are forwarded to ``InferenceClient.text_generation`` on
    every call.
    """

    model_name: str
    inference_client: InferenceClient

    def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
        """Create the wrapper and its underlying inference client.

        Args:
            model_name: Hugging Face model id used for generation.
            hf_token: API token handed to ``InferenceClient``.
            kwargs: Extra ``text_generation`` parameters. ``None`` (the
                default) means "no extra parameters".
        """
        inference_client = InferenceClient(model=model_name, token=hf_token)
        super().__init__(
            model_name=model_name,
            # NOTE(review): no `hf_token` field is declared on this model; whether
            # pydantic ignores or rejects it depends on the LangChain base-class
            # config - confirm. The token is already held by `inference_client`.
            hf_token=hf_token,
            # The pydantic-v1 `kwargs` field (Dict with default_factory) does not
            # accept None, so map the documented default to an empty dict.
            kwargs=kwargs if kwargs is not None else {},
            inference_client=inference_client,
        )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **call_kwargs: Any,
    ) -> str:
        """Generate a completion for ``prompt``.

        Args:
            prompt: Fully formatted prompt text.
            stop: Stop sequences; not supported and must be ``None``.
            run_manager: LangChain callback manager; accepted for interface
                compatibility and not used.
            **call_kwargs: Per-call generation parameters; they override the
                configured ``self.kwargs`` values with the same name.

        Returns:
            The generated text (prompt not included).

        Raises:
            ValueError: If ``stop`` is provided.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        params = {**self.kwargs, **call_kwargs}
        # With stream=True the client yields text chunks; join them into one string.
        response_gen = self.inference_client.text_generation(
            prompt, **params, stream=True, return_full_text=False
        )
        return "".join(response_gen)

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM type."""
        return "custom"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Parameters that identify this LLM instance."""
        return {"model_name": self.model_name}
app = FastAPI(title="Passport Recognition API")

# Cross-origin policy: any origin, method and header is accepted, with credentials.
# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive policy; confirm it is intended before exposing this beyond a demo.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)

# Single PaddleOCR engine, built once at import time and reused by every request.
ocr = PaddleOCR(use_angle_cls=True, lang='en')
# Prompt sent to the LLM. The endpoint decodes the reply with `json.loads`, so
# the prompt asks explicitly for a bare JSON object with exactly the OCRData keys
# and string values. Literal braces are doubled because PromptTemplate uses
# f-string formatting; `{ocr_result}` is the only input variable.
template = """Below is a poorly read OCR result of a passport.
OCR Result:
{ocr_result}
Fill in the categories below using the OCR result. You can correct spellings and make other adjustments. Dates should be in 01-JAN-2000 format.
Respond with only a JSON object containing exactly these keys, all with string values (use "" when a value is unknown):
{{
"countryName": "",
"dateOfBirth": "",
"dateOfExpiry": "",
"dateOfIssue": "",
"documentNumber": "",
"givenNames": "",
"name": "",
"surname": "",
"mrz": ""
}}
json output:
"""
prompt = PromptTemplate(template=template, input_variables=["ocr_result"])
class MRZData(BaseModel):
    # Machine-readable-zone data built from passporteye's `MRZ.to_dict()` output
    # (unpacked into this model by `create_response_data`). Keys not declared
    # here are ignored under pydantic's default `extra` setting.
    # NOTE(review): passporteye may omit the `*_personal_number` keys for
    # non-TD3 documents, which would fail validation here - confirm.

    # Decoded document fields, kept as raw strings.
    date_of_birth: str
    expiration_date: str
    type: str
    number: str
    names: str
    country: str

    # Check-digit values (string form, per passporteye naming).
    check_number: str
    check_date_of_birth: str
    check_expiration_date: str
    check_composite: str
    check_personal_number: str

    # Validation flags for the corresponding check digits.
    valid_number: bool
    valid_date_of_birth: bool
    valid_expiration_date: bool
    valid_composite: bool
    valid_personal_number: bool

    # Extraction method reported with the MRZ result.
    method: str
class OCRData(BaseModel):
    # Passport fields structured by the LLM from the raw OCR text. Field names
    # match the JSON keys requested in the prompt template; every field is a
    # required string, so a missing key or a null value fails validation.
    countryName: str
    dateOfBirth: str
    dateOfExpiry: str
    dateOfIssue: str
    documentNumber: str
    givenNames: str
    name: str
    surname: str
    mrz: str
class ResponseData(BaseModel):
    # Response body of the passport-recognition endpoint.
    documentName: str  # document type label ("Passport" from create_response_data)
    errorCode: int  # 0 when produced by create_response_data
    mrz: MRZData  # data decoded from the machine-readable zone
    ocr: OCRData  # LLM-structured fields from the full-page OCR
    status: str  # "ok" when produced by create_response_data
def create_response_data(mrz, ocr_data):
    """Wrap MRZ and OCR field mappings into a successful ``ResponseData``.

    Args:
        mrz: Mapping of MRZ fields (the endpoint passes ``MRZ.to_dict()``);
            validated into ``MRZData``.
        ocr_data: Mapping of LLM-extracted fields; validated into ``OCRData``.

    Returns:
        ResponseData with documentName "Passport", errorCode 0 and status "ok".
    """
    # Validate the MRZ part first, then the OCR part (same order as before).
    mrz_model = MRZData(**mrz)
    ocr_model = OCRData(**ocr_data)
    return ResponseData(
        documentName="Passport",
        errorCode=0,
        mrz=mrz_model,
        ocr=ocr_model,
        status="ok",
    )
def _flatten_ocr_result(result):
    """Flatten PaddleOCR output into ``{coordinates, text, confidence}`` dicts.

    PaddleOCR returns one entry per page. A page entry may be ``None`` (e.g.
    when no text is detected), so such pages are skipped instead of raising
    ``TypeError``. Each line is unpacked as ``(coordinates, (text, confidence))``.
    """
    return [
        {'coordinates': coordinates, 'text': text, 'confidence': confidence}
        for page in (result or [])
        for coordinates, (text, confidence) in (page or [])
    ]


def _parse_llm_json(response_str):
    """Extract and decode the JSON object from a raw LLM completion.

    Removes a trailing ``</s>`` end-of-sequence marker and any text outside the
    outermost ``{...}`` span, then decodes it with ``json.loads``.

    Raises:
        ValueError: If no JSON object can be decoded (``json.JSONDecodeError``
            is a ``ValueError`` subclass) or the decoded value is not an object.
    """
    text = response_str.strip()
    # str.rstrip() treats its argument as a *set of characters*, which could eat
    # legitimate trailing characters; remove only the exact marker.
    if text.endswith("</s>"):
        text = text[:-len("</s>")].rstrip()
    # Tolerate prose or code fences around the JSON object.
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end > start:
        text = text[start:end + 1]
    data = json.loads(text)
    if not isinstance(data, dict):
        raise ValueError("LLM response is not a JSON object")
    return data


@app.post("/recognize_passport", response_model=ResponseData, status_code=status.HTTP_201_CREATED)
async def recognize_passport(image: UploadFile = File(...)):
    """Passport information extraction from a provided image file.
    \f
    Pipeline: decode the MRZ with passporteye, OCR the whole page with
    PaddleOCR, then ask the LLM to structure the OCR lines into ``OCRData``.

    Raises:
        HTTPException: 422 when no MRZ can be detected in the image; 500 for
            any other failure (OCR, LLM call, JSON decoding, validation).
    """
    # NOTE(review): OCR and the LLM call are synchronous and run on the event
    # loop, so requests are serialized per worker. The shared `ocr` engine is
    # presumably not thread-safe; check that before offloading to a thread pool.
    img_path = None
    try:
        image_bytes = await image.read()
        mrz = read_mrz(image_bytes)
        if mrz is None:
            # Fail fast with a client error instead of an AttributeError -> 500
            # after the expensive OCR and LLM steps.
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="No machine-readable zone (MRZ) could be detected in the image.",
            )

        # A per-request temp file avoids concurrent requests/workers clobbering a
        # shared path, and is removed in `finally` so passport images are not
        # left on disk. The ".jpg" suffix matches the previous file name.
        with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
            tmp.write(image_bytes)
            img_path = tmp.name

        result = ocr.ocr(img_path, cls=True)
        json_result = _flatten_ocr_result(result)

        llm = CustomInferenceClient(model_name=model_name, hf_token=hf_token, kwargs=kwargs)
        llm_chain = LLMChain(prompt=prompt, llm=llm)
        response_str = llm_chain.run(ocr_result=json_result)
        ocr_data = _parse_llm_json(response_str)
        return create_response_data(mrz.to_dict(), ocr_data)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Internal server error: {str(e)}"
        ) from e
    finally:
        if img_path is not None:
            # Best-effort cleanup; must never mask the real response or error.
            with contextlib.suppress(OSError):
                os.remove(img_path)