ethanrom committed on
Commit
47315cd
1 Parent(s): 1bc9ea6

Upload 5 files

Browse files
Files changed (5) hide show
  1. app.py +43 -0
  2. arial.ttf +0 -0
  3. packages.txt +1 -0
  4. requirements.txt +5 -0
  5. server.py +176 -0
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ import json
4
+ import os
5
+ import requests
6
+ import socket
7
+
8
def start_server():
    """Launch the uvicorn inference backend in the background and flag the session.

    The backend is started as a detached child process so this Streamlit
    script can continue rendering; a blocking call (``os.system``) would wait
    for uvicorn to exit, which it never does while serving. If port 8080
    already accepts connections, no new process is spawned. In both cases
    ``st.session_state['server_started']`` is set to ``True``.
    """
    # Function-scope import: keeps this block self-contained.
    import subprocess

    # NOTE(review): the target module is `inference_server`, while the FastAPI
    # app in this upload is defined in server.py -- confirm the module name.
    if not is_port_in_use(8080):
        subprocess.Popen([
            "uvicorn", "inference_server:app",
            "--port", "8080",
            "--host", "0.0.0.0",
            "--workers", "2",
        ])
    st.session_state['server_started'] = True
11
+
12
def is_port_in_use(port, host="127.0.0.1"):
    """Return True if a TCP connection to ``host:port`` succeeds.

    Args:
        port: TCP port number to probe.
        host: Address to connect to. Defaults to the loopback address, which
            is a valid connect target on every platform (``0.0.0.0`` is not).

    Returns:
        True when ``connect_ex`` reports success (0), False otherwise.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        # Bound the wait so an unresponsive host cannot stall the caller.
        probe.settimeout(1.0)
        return probe.connect_ex((host, port)) == 0
16
+
17
def recognize_passport(image_path):
    """POST the image at ``image_path`` to the local recognition endpoint.

    Args:
        image_path: Path of the image file to upload.

    Returns:
        The decoded JSON body of the HTTP response.
    """
    # The context manager closes the file handle once the upload is done;
    # previously the handle was opened and never closed.
    with open(image_path, 'rb') as image_file:
        response = requests.post(
            "http://0.0.0.0:8080/recognize_passport",
            files={'image': image_file},
        )
    return response.json()
21
+
22
# --- Streamlit page script -------------------------------------------------
# Make sure the per-session flag exists, then start the backend if this
# session has not done so yet.
if 'server_started' not in st.session_state:
    st.session_state['server_started'] = False

if not st.session_state['server_started']:
    start_server()

st.title('Passport Recognition Demo')

image_path = st.file_uploader("Upload Passport Image", type=["jpg", "jpeg", "png"])

if image_path is not None:
    import tempfile

    st.image(image_path, caption="Uploaded Image.", use_column_width=True)
    st.write("")
    st.write("Classifying...")

    # Write the upload to a unique temporary file instead of a fixed
    # "temp_image.jpg", so concurrent sessions cannot overwrite each other.
    # getvalue() returns the whole payload regardless of the current read
    # position of the uploaded buffer.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        tmp.write(image_path.getvalue())
    try:
        passport_info = recognize_passport(tmp.name)
    finally:
        os.remove(tmp.name)

    st.markdown('## Passport Recognition Results')
    st.write(json.dumps(passport_info, indent=2))
arial.ttf ADDED
Binary file (367 kB). View file
 
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ tesseract-ocr-all
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ paddlepaddle -i https://pypi.tuna.tsinghua.edu.cn/simple
2
+ fastapi
3
+ uvicorn
4
+ passporteye
5
+ paddleocr
server.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import os
import tempfile
from typing import Any, Optional, Dict, List

from fastapi import FastAPI, File, UploadFile, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import InferenceClient
from langchain.chains import LLMChain
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate
from paddleocr import PaddleOCR
from passporteye import read_mrz
from pydantic import BaseModel, Field
12
+
13
+
14
+
15
# Hugging Face API token, read from the `apiToken` environment variable
# (None when the variable is not set). Requires `os` to be imported.
HF_token = os.getenv("apiToken")

# Model identifier passed to the inference client.
model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# Lower-case alias used where the client is constructed.
hf_token = HF_token
# Generation parameters forwarded to the text-generation call.
kwargs = {
    "max_new_tokens": 500,
    "temperature": 0.1,
    "top_p": 0.95,
    "repetition_penalty": 1.0,
    "do_sample": True,
}
20
+
21
class KwArgsModel(BaseModel):
    """Pydantic mixin holding a free-form ``kwargs`` mapping."""

    # Defaults to a fresh empty dict per instance.
    kwargs: Dict[str, Any] = Field(default_factory=dict)
23
+
24
class CustomInferenceClient(LLM, KwArgsModel):
    """LangChain ``LLM`` backed by ``huggingface_hub.InferenceClient`` text generation."""

    # Model identifier the inference client was created for.
    model_name: str
    # Client used by `_call` to run text generation.
    inference_client: InferenceClient

    def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
        """Create the inference client and initialise the model fields.

        Args:
            model_name: Model identifier passed to ``InferenceClient``.
            hf_token: Token passed to ``InferenceClient``.
            kwargs: Generation parameters forwarded on every call. ``None``
                (the default) is treated as "no extra parameters".
        """
        inference_client = InferenceClient(model=model_name, token=hf_token)
        super().__init__(
            model_name=model_name,
            hf_token=hf_token,
            # The `kwargs` field is a non-optional Dict, so the None default
            # must be replaced by an empty dict before validation.
            kwargs=kwargs if kwargs is not None else {},
            inference_client=inference_client
        )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None
    ) -> str:
        """Generate a completion for ``prompt`` and return it as one string.

        Raises:
            ValueError: If ``stop`` is provided; stop sequences are rejected.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        # Streaming is requested, so the call yields pieces that are
        # concatenated into the final text.
        token_stream = self.inference_client.text_generation(
            prompt, **self.kwargs, stream=True, return_full_text=False
        )
        return ''.join(token_stream)

    @property
    def _llm_type(self) -> str:
        """Identifier string for this LLM implementation."""
        return "custom"

    @property
    def _identifying_params(self) -> dict:
        """Parameters that identify this LLM instance."""
        return {"model_name": self.model_name}
55
+
56
# FastAPI application object that the route below is registered on.
app = FastAPI(title="Passport Recognition API")

# CORS configuration: every origin, method and header is allowed, with
# credentials enabled.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# Shared OCR engine, created once at import time.
ocr = PaddleOCR(lang='en', use_angle_cls=True)
67
# Prompt sent to the LLM: the raw OCR output is substituted for
# {ocr_result}, followed by the list of fields the model is asked to fill.
template = """below is poorly read ocr result of a passport.
OCR Result:
{ocr_result}

Fill the below catergories using the OCR Results. you can correct spellings and make other adujustments. Dates should be in 01-JAN-2000 format.

"countryName": "",
"dateOfBirth": "",
"dateOfExpiry": "",
"dateOfIssue": "",
"documentNumber": "",
"givenNames": "",
"name": "",
"surname": "",
"mrz": ""

json output:
"""
# The template has a single placeholder, `ocr_result`.
prompt = PromptTemplate(input_variables=["ocr_result"], template=template)
86
+
87
class MRZData(BaseModel):
    """Machine-readable-zone section of the API response."""

    # Parsed values.
    date_of_birth: str
    expiration_date: str
    type: str
    number: str
    names: str
    country: str

    # Check fields, kept as strings.
    check_number: str
    check_date_of_birth: str
    check_expiration_date: str
    check_composite: str
    check_personal_number: str

    # Validity flags.
    valid_number: bool
    valid_date_of_birth: bool
    valid_expiration_date: bool
    valid_composite: bool
    valid_personal_number: bool

    method: str
105
+
106
class OCRData(BaseModel):
    """OCR section of the API response; field names match the keys requested in the LLM prompt."""

    countryName: str
    dateOfBirth: str
    dateOfExpiry: str
    dateOfIssue: str
    documentNumber: str
    givenNames: str
    name: str
    surname: str
    mrz: str
116
+
117
class ResponseData(BaseModel):
    """Top-level response body of the passport recognition endpoint."""

    documentName: str
    errorCode: int
    # Nested sections: MRZ values and OCR values.
    mrz: MRZData
    ocr: OCRData
    status: str
123
+
124
+
125
def create_response_data(mrz, ocr_data):
    """Build the success response from MRZ and OCR field mappings.

    Args:
        mrz: Mapping expanded into ``MRZData``.
        ocr_data: Mapping expanded into ``OCRData``.

    Returns:
        A ``ResponseData`` with document name "Passport", error code 0 and
        status "ok".
    """
    mrz_section = MRZData(**mrz)
    ocr_section = OCRData(**ocr_data)
    return ResponseData(
        documentName="Passport",
        errorCode=0,
        mrz=mrz_section,
        ocr=ocr_section,
        status="ok",
    )
133
+
134
+
135
@app.post("/recognize_passport", response_model=ResponseData, status_code=status.HTTP_201_CREATED)
async def recognize_passport(image: UploadFile = File(...)):
    """Extract passport information from an uploaded image file.

    The image is parsed for its machine-readable zone, run through OCR, and
    the OCR output is sent to the LLM, whose JSON answer fills the OCR part
    of the response.

    Args:
        image: The uploaded image.

    Returns:
        A ``ResponseData`` built from the MRZ values and the LLM's JSON.

    Raises:
        HTTPException: 422 when no machine-readable zone is found; 500 for
            any other failure (the original error is chained as the cause).
    """
    try:
        image_bytes = await image.read()
        mrz = read_mrz(image_bytes)
        if mrz is None:
            # read_mrz yields None when it finds no MRZ; report that clearly
            # instead of failing later on `mrz.to_dict()`.
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="No machine-readable zone could be detected in the image."
            )

        # A unique temporary file replaces the fixed 'image.jpg', so
        # concurrent requests or workers cannot overwrite each other's input.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            tmp.write(image_bytes)
            img_path = tmp.name
        try:
            result = ocr.ocr(img_path, cls=True)
        finally:
            os.remove(img_path)

        json_result = []
        for page in result or []:
            # A page with no detected text may be None/empty; skip it.
            if not page:
                continue
            for coordinates, (text, confidence) in page:
                json_result.append({
                    'coordinates': coordinates,
                    'text': text,
                    'confidence': confidence
                })

        llm = CustomInferenceClient(model_name=model_name, hf_token=hf_token, kwargs=kwargs)
        llm_chain = LLMChain(prompt=prompt, llm=llm)
        response_str = llm_chain.run(ocr_result=json_result).strip()

        # Remove a trailing end-of-sequence marker as a whole suffix;
        # str.rstrip("</s>") would strip any run of the characters < / s >.
        eos_marker = "</s>"
        if response_str.endswith(eos_marker):
            response_str = response_str[:-len(eos_marker)]

        # Keep only the outermost JSON object in case the model wrapped it in
        # extra text such as a code fence.
        start, end = response_str.find("{"), response_str.rfind("}")
        if 0 <= start < end:
            response_str = response_str[start:end + 1]

        ocr_data = json.loads(response_str)

        return create_response_data(mrz.to_dict(), ocr_data)

    except HTTPException:
        # Deliberate HTTP errors pass through unchanged.
        raise

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Internal server error: {str(e)}"
        ) from e