Spaces:
Paused
Paused
File size: 4,994 Bytes
47315cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
import json
import os
import tempfile
from typing import Any, Optional, Dict, List

from fastapi import FastAPI, File, UploadFile, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from paddleocr import PaddleOCR
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from passporteye import read_mrz
from pydantic import BaseModel, Field
from huggingface_hub import InferenceClient
from langchain.llms.base import LLM
# Model served through the Hugging Face Inference API.
model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Access token read from the "apiToken" environment variable (None if unset);
# `hf_token` is the lower-case alias handed to the LLM client.
HF_token = os.getenv("apiToken")
hf_token = HF_token

# Default text-generation parameters for every LLM call: sampling with a low
# temperature and at most 500 new tokens.
kwargs = dict(
    max_new_tokens=500,
    temperature=0.1,
    top_p=0.95,
    repetition_penalty=1.0,
    do_sample=True,
)
class KwArgsModel(BaseModel):
    """Pydantic mixin that adds a ``kwargs`` dictionary field.

    ``CustomInferenceClient`` inherits it and forwards the stored values to
    ``InferenceClient.text_generation``.
    """
    # Extra generation keyword arguments; each instance gets its own empty
    # dict by default (default_factory avoids a shared mutable default).
    kwargs: Dict[str, Any] = Field(default_factory=dict)
class CustomInferenceClient(LLM, KwArgsModel):
    """LangChain ``LLM`` backed by a Hugging Face ``InferenceClient``.

    Generation parameters live in the ``kwargs`` field inherited from
    ``KwArgsModel`` and are forwarded to ``text_generation`` on every call.
    """

    model_name: str
    inference_client: InferenceClient

    def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
        """Create the wrapper.

        Args:
            model_name: Hugging Face model id to query.
            hf_token: Access token; it is only given to ``InferenceClient``
                and is not stored as a field on this model.
            kwargs: Default ``text_generation`` parameters. ``None`` means no
                extra parameters.
        """
        inference_client = InferenceClient(model=model_name, token=hf_token)
        super().__init__(
            model_name=model_name,
            # The field is typed Dict[str, Any]; passing the None default
            # straight through would fail validation, so use an empty dict.
            kwargs={} if kwargs is None else dict(kwargs),
            inference_client=inference_client,
        )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for *prompt*.

        Args:
            prompt: Fully rendered prompt text.
            stop: Stop sequences; not supported, must be None.
            run_manager: LangChain callback manager; accepted so the method
                matches LangChain's ``_call`` signature, otherwise unused.
            **kwargs: Per-call generation parameters; they override the
                defaults stored in ``self.kwargs``.

        Returns:
            The generated continuation only (the prompt is not echoed).

        Raises:
            ValueError: If *stop* is given.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        params = {**self.kwargs, **kwargs}
        # Stream the tokens and stitch them back into a single string.
        tokens = self.inference_client.text_generation(
            prompt, **params, stream=True, return_full_text=False
        )
        return ''.join(tokens)

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM type."""
        return "custom"

    @property
    def _identifying_params(self) -> dict:
        """Parameters that identify this LLM instance to LangChain."""
        return {"model_name": self.model_name}
# ASGI application exposing the passport recognition endpoint.
app = FastAPI(title="Passport Recognition API")

# CORS is fully open: every origin, HTTP method and request header is accepted,
# and credentialed requests are allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive — confirm this is intended before exposing the API publicly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_origins=["*"],
)

# One PaddleOCR engine, created at import time and shared by every request;
# English recognition with the angle classifier enabled.
ocr = PaddleOCR(lang='en', use_angle_cls=True)
# Prompt for the LLM. `{ocr_result}` is the only template variable; the literal
# braces of the JSON skeleton are doubled ({{ }}) because PromptTemplate's
# default f-string format would otherwise read them as placeholders. The reply
# is decoded with json.loads downstream, so the prompt asks for a bare JSON
# object with exactly the OCRData keys.
template = """Below is a poorly read OCR result of a passport.
OCR Result:
{ocr_result}
Fill in the categories below using the OCR result. You may correct spelling mistakes and make other adjustments. Dates should be in 01-JAN-2000 format.
Reply with a single JSON object containing exactly these keys with string values, and nothing else:
{{
"countryName": "",
"dateOfBirth": "",
"dateOfExpiry": "",
"dateOfIssue": "",
"documentNumber": "",
"givenNames": "",
"name": "",
"surname": "",
"mrz": ""
}}
json output:
"""
prompt = PromptTemplate(template=template, input_variables=["ocr_result"])
class MRZData(BaseModel):
    """Machine-readable-zone (MRZ) section of the API response.

    ``create_response_data`` builds it with ``MRZData(**mrz)``, where the
    endpoint supplies ``mrz.to_dict()`` from passporteye's ``read_mrz``; the
    field names therefore mirror that dict's keys.
    """
    # Document and holder fields, typed as plain strings.
    date_of_birth: str
    expiration_date: str
    type: str
    number: str
    names: str
    country: str
    # check_* values as reported by passporteye (presumably the MRZ check
    # digits — confirm against the passporteye documentation).
    check_number: str
    check_date_of_birth: str
    check_expiration_date: str
    check_composite: str
    check_personal_number: str
    # valid_* flags as reported by passporteye (presumably whether each check
    # digit verified — confirm).
    valid_number: bool
    valid_date_of_birth: bool
    valid_expiration_date: bool
    valid_composite: bool
    valid_personal_number: bool
    # Extraction-method label reported by passporteye; meaning not visible here.
    method: str
class OCRData(BaseModel):
    """Passport fields the LLM extracted from the full-page OCR text.

    Keys match the categories listed in the LLM prompt. Every field defaults to
    an empty string (the placeholder value the prompt shows). A completion that
    omits a key therefore still validates, instead of failing the whole request
    with a 500. Date fields are requested in 01-JAN-2000 format but are not
    validated here.
    """
    countryName: str = ""
    dateOfBirth: str = ""
    dateOfExpiry: str = ""
    dateOfIssue: str = ""
    documentNumber: str = ""
    givenNames: str = ""
    name: str = ""
    surname: str = ""
    mrz: str = ""
class ResponseData(BaseModel):
    """Response body of ``POST /recognize_passport``.

    Built by ``create_response_data``, which always sets
    ``documentName="Passport"``, ``errorCode=0`` and ``status="ok"``.
    """
    documentName: str
    errorCode: int
    # Fields decoded from the machine-readable zone by passporteye.
    mrz: MRZData
    # Fields the LLM extracted from the PaddleOCR text.
    ocr: OCRData
    status: str
def create_response_data(mrz, ocr_data):
    """Assemble the API response from MRZ and LLM-extracted fields.

    Args:
        mrz: Mapping of MRZ fields, unpacked into ``MRZData``.
        ocr_data: Mapping of OCR fields, unpacked into ``OCRData``.

    Returns:
        A ``ResponseData`` labelled as a "Passport" document, with
        ``errorCode`` 0 and ``status`` "ok".
    """
    # Validate the MRZ part first, then the OCR part.
    mrz_section = MRZData(**mrz)
    ocr_section = OCRData(**ocr_data)
    return ResponseData(
        status="ok",
        documentName="Passport",
        errorCode=0,
        ocr=ocr_section,
        mrz=mrz_section,
    )
def _ocr_lines_to_records(result):
    """Flatten PaddleOCR output into ``{coordinates, text, confidence}`` dicts.

    Args:
        result: Return value of ``ocr.ocr(...)``: one entry per page, each a
            sequence of ``(coordinates, (text, confidence))`` lines.

    Returns:
        A flat list with one dict per recognised text line, in OCR order.
    """
    records = []
    for page in result or []:
        # Skip pages with no detected text; iterating a None page would raise
        # TypeError and turn an image without text into a 500.
        if not page:
            continue
        for coordinates, (text, confidence) in page:
            records.append({
                'coordinates': coordinates,
                'text': text,
                'confidence': confidence,
            })
    return records


def _parse_llm_json(response_str):
    """Decode the JSON object contained in an LLM completion.

    Only the span from the first ``{`` to the last ``}`` is decoded. This drops
    a trailing end-of-sequence marker such as ``</s>``, along with any prose or
    code fence the model puts around the object.

    Raises:
        ValueError: If the completion contains no decodable JSON object
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    start = response_str.find('{')
    end = response_str.rfind('}')
    if start == -1 or end < start:
        raise ValueError(f"LLM response contains no JSON object: {response_str!r}")
    return json.loads(response_str[start:end + 1])


@app.post("/recognize_passport", response_model=ResponseData, status_code=status.HTTP_201_CREATED)
async def recognize_passport(image: UploadFile = File(...)):
    """Passport information extraction from a provided image file.

    The request is handled in three steps:

    1. Decode the MRZ with passporteye.
    2. Read all page text with PaddleOCR.
    3. Ask the LLM to map that text onto the ``OCRData`` fields.

    Raises:
        HTTPException: 400 for an empty upload; 422 when no MRZ is found in the
            image; 500 for any other failure (OCR, LLM call, unparsable LLM
            output, response validation).
    """
    try:
        image_bytes = await image.read()
        if not image_bytes:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Uploaded file is empty.",
            )

        mrz = read_mrz(image_bytes)
        if mrz is None:
            # read_mrz returns None when it cannot locate an MRZ; report that
            # as a client error instead of an AttributeError-driven 500.
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="No machine-readable zone (MRZ) found in the image.",
            )

        # Use a unique temporary file instead of a fixed 'image.jpg' in the
        # working directory. Requests in different worker processes then
        # cannot overwrite each other's upload, and the file is always removed.
        fd, img_path = tempfile.mkstemp(suffix='.jpg')
        try:
            with os.fdopen(fd, 'wb') as f:
                f.write(image_bytes)
            result = ocr.ocr(img_path, cls=True)
        finally:
            try:
                os.remove(img_path)
            except OSError:
                pass  # best-effort cleanup; must not mask the OCR outcome

        json_result = _ocr_lines_to_records(result)

        llm = CustomInferenceClient(model_name=model_name, hf_token=hf_token, kwargs=kwargs)
        llm_chain = LLMChain(prompt=prompt, llm=llm)
        response_str = llm_chain.run(ocr_result=json_result)
        ocr_data = _parse_llm_json(response_str)

        return create_response_data(mrz.to_dict(), ocr_data)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Internal server error: {str(e)}"
        ) from e