import json
import os
import tempfile
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, File, HTTPException, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import InferenceClient
from langchain.chains import LLMChain
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate
from paddleocr import PaddleOCR
from passporteye import read_mrz
from pydantic import BaseModel, Field

# Hugging Face API token, read from the `apiToken` environment variable
# (None when the variable is unset).
HF_token = os.getenv("apiToken")
model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
hf_token = HF_token
# Generation parameters forwarded verbatim to InferenceClient.text_generation.
kwargs = {
    "max_new_tokens": 500,
    "temperature": 0.1,
    "top_p": 0.95,
    "repetition_penalty": 1.0,
    "do_sample": True,
}


class KwArgsModel(BaseModel):
    """Mixin contributing a ``kwargs`` dict field (empty by default)."""

    kwargs: Dict[str, Any] = Field(default_factory=dict)


class CustomInferenceClient(LLM, KwArgsModel):
    """LangChain ``LLM`` adapter around a Hugging Face ``InferenceClient``.

    Each call streams a completion for the prompt and returns the joined
    generated text (the prompt itself is not echoed back).
    """

    model_name: str
    inference_client: InferenceClient

    def __init__(self, model_name: str, hf_token: str,
                 kwargs: Optional[Dict[str, Any]] = None):
        inference_client = InferenceClient(model=model_name, token=hf_token)
        super().__init__(
            model_name=model_name,
            hf_token=hf_token,
            # The `kwargs` field is a non-optional dict, so the None default
            # must become {} instead of being handed to validation as-is.
            kwargs=kwargs if kwargs is not None else {},
            inference_client=inference_client,
        )

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Generate a completion for *prompt*.

        Raises:
            ValueError: if stop sequences are supplied (not supported here).
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        # stream=True yields text chunks; return_full_text=False omits the prompt.
        chunks = self.inference_client.text_generation(
            prompt, **self.kwargs, stream=True, return_full_text=False
        )
        return "".join(chunks)

    @property
    def _llm_type(self) -> str:
        return "custom"

    @property
    def _identifying_params(self) -> dict:
        return {"model_name": self.model_name}


app = FastAPI(title="Passport Recognition API")

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; restrict allow_origins for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

ocr = PaddleOCR(use_angle_cls=True, lang='en')

template = """below is poorly read ocr result of a passport.

OCR Result:
{ocr_result}

Fill the below catergories using the OCR Results. you can correct spellings and make other adujustments. 
Dates should be in 01-JAN-2000 format.

"countryName": "",
"dateOfBirth": "",
"dateOfExpiry": "",
"dateOfIssue": "",
"documentNumber": "",
"givenNames": "",
"name": "",
"surname": "",
"mrz": ""

json output:
"""

prompt = PromptTemplate(template=template, input_variables=["ocr_result"])

# Built once at import time instead of on every request.
llm = CustomInferenceClient(model_name=model_name, hf_token=hf_token, kwargs=kwargs)
llm_chain = LLMChain(prompt=prompt, llm=llm)


class MRZData(BaseModel):
    """MRZ fields returned to clients, populated from passporteye's MRZ.to_dict()."""

    date_of_birth: str
    expiration_date: str
    type: str
    number: str
    names: str
    country: str
    check_number: str
    check_date_of_birth: str
    check_expiration_date: str
    check_composite: str
    check_personal_number: str
    valid_number: bool
    valid_date_of_birth: bool
    valid_expiration_date: bool
    valid_composite: bool
    valid_personal_number: bool
    method: str


class OCRData(BaseModel):
    """Passport fields the LLM extracts from the raw OCR text."""

    countryName: str
    dateOfBirth: str
    dateOfExpiry: str
    dateOfIssue: str
    documentNumber: str
    givenNames: str
    name: str
    surname: str
    mrz: str


class ResponseData(BaseModel):
    """Response body of ``/recognize_passport``."""

    documentName: str
    errorCode: int
    mrz: MRZData
    ocr: OCRData
    status: str


def create_response_data(mrz, ocr_data):
    """Wrap MRZ and OCR dicts into a successful ``ResponseData``."""
    return ResponseData(
        documentName="Passport",
        errorCode=0,
        mrz=MRZData(**mrz),
        ocr=OCRData(**ocr_data),
        status="ok"
    )


def _flatten_ocr_result(result: Any) -> List[Dict[str, Any]]:
    """Flatten PaddleOCR output into ``{coordinates, text, confidence}`` dicts.

    The result holds one entry per page, each a list of
    ``[box, (text, confidence)]`` lines. A page may be None when nothing was
    detected; such pages are skipped instead of raising TypeError.
    """
    return [
        {"coordinates": coordinates, "text": text, "confidence": confidence}
        for page in (result or [])
        for coordinates, (text, confidence) in (page or [])
    ]


def _parse_llm_json(text: str) -> Dict[str, Any]:
    """Decode the JSON object contained in an LLM completion.

    Uses the outermost ``{...}`` span, so surrounding prose or Markdown fences
    are tolerated. If there are no braces, the reply is assumed to be the bare
    ``"key": "value"`` pairs shown in the prompt, and braces are added.

    Raises:
        ValueError: if no JSON object can be decoded (json.JSONDecodeError is
            a ValueError subclass).
    """
    cleaned = text.strip()
    start, end = cleaned.find("{"), cleaned.rfind("}")
    if start != -1 and end > start:
        candidate = cleaned[start:end + 1]
    else:
        # Drop Markdown fences (``` / ```json) before wrapping bare pairs.
        candidate = cleaned.strip("`").strip()
        if candidate.lower().startswith("json"):
            candidate = candidate[4:]
        candidate = "{" + candidate.strip().rstrip(",") + "}"
    data = json.loads(candidate)
    if not isinstance(data, dict):
        raise ValueError("LLM response is not a JSON object")
    return data


def _to_ocr_fields(data: Dict[str, Any]) -> Dict[str, str]:
    """Keep only OCRData's fields, mapping missing or null values to ""."""
    return {
        name: "" if data.get(name) is None else str(data[name])
        for name in OCRData.__annotations__
    }


@app.post("/recognize_passport", response_model=ResponseData,
          status_code=status.HTTP_201_CREATED)
async def recognize_passport(image: UploadFile = File(...)):
    """Passport information extraction from a provided image file.

    Runs MRZ detection (passporteye) and full-page OCR (PaddleOCR) on the
    upload, then asks the LLM to structure the OCR text into OCRData fields.

    Raises:
        HTTPException: 422 when no MRZ is found in the image; 500 for any
            other failure.
    """
    # NOTE(review): read_mrz, ocr.ocr and llm_chain.run are synchronous and run
    # on the event-loop thread, so other requests wait while they execute.
    try:
        image_bytes = await image.read()
        mrz = read_mrz(image_bytes)
        if mrz is None:
            # Checked before the (slow, billable) OCR + LLM steps.
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="No machine-readable zone (MRZ) detected in the image.",
            )

        # Use a unique temp file that is always deleted. A fixed path would
        # leave the passport image on disk and could be overwritten by another
        # worker process sharing the working directory.
        fd, img_path = tempfile.mkstemp(suffix=".jpg")
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(image_bytes)
            result = ocr.ocr(img_path, cls=True)
        finally:
            os.remove(img_path)

        json_result = _flatten_ocr_result(result)
        response_str = llm_chain.run(ocr_result=json_result)
        ocr_data = _to_ocr_fields(_parse_llm_json(response_str))
        return create_response_data(mrz.to_dict(), ocr_data)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Internal server error: {str(e)}"
        ) from e