Annas Dev committed
Commit 8cb5b3c
1 Parent(s): d0dae0b

add basic files
.gitignore CHANGED
@@ -1,2 +1,3 @@
 .env
-venv
+venv
+tmp
app.py CHANGED
@@ -1,15 +1,45 @@
 import gradio as gr
+from dotenv import load_dotenv
+import os
+import torch
+import warnings
+from PIL import Image
+from util import file_helper
+from inference.ocr import prepare_batch_for_inference
+from inference.inference_handler import handle
 
+os.system('sudo apt install -y -q tesseract-ocr')
+os.system('sudo apt install -y -q libtesseract-dev')
+
+load_dotenv()
 
 def get_model():
-    return ''
+    model_dir = "tmp"
+    model_filename = 'receipt.pth'
+    full_path = os.path.join(model_dir, model_filename)
+    if os.path.isfile(full_path):
+        return full_path
+
+    return file_helper.download_gdrive(os.getenv('MODEL_ID'), model_dir, model_filename)
+
+def run_inference(model_path, images_path):
+    try:
+        inference_batch = prepare_batch_for_inference(images_path)
+        context = {"model_dir": model_path}
+        print('handle....')
+        handle(inference_batch, context)
+    except Exception as err:
+        print('err...', err)
 
+def run(img_path):
+    print('img path: ', img_path)
+    model_path = get_model()
+    run_inference(model_path, [img_path])
 
-def run(img):
-    print('running...')
-    return img
+    return Image.open(img_path)
 
 
 gr.Markdown('Upload Foto Wajah Kamu (Pastikan hanya terdapat SATU wajah pada)')
-iface = gr.Interface(fn=run, inputs=["image"], outputs="image")
+iface = gr.Interface(fn=run, inputs=gr.Image(type="filepath"), outputs="image")
 iface.launch()
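Note that the gr.Markdown call above is created outside any Blocks context, so it never renders (its Indonesian text translates to "Upload a photo of your face (make sure there is only ONE face in it)"); the Interface with a filepath-typed image input is what serves run(). A hedged smoke test of the new flow, assuming run() is in scope, MODEL_ID is set in .env, tesseract-ocr is installed, and a hypothetical sample image exists:

# Exercising run() directly (hypothetical path): get_model() caches the
# checkpoint as tmp/receipt.pth, run_inference() OCRs the image and invokes
# the handler, and the original PIL image is returned unchanged.
out = run("samples/receipt.jpg")
out.save("receipt_echo.jpg")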
inference/__pycache__/annotate_image.cpython-38.pyc ADDED
Binary file (1.85 kB)

inference/__pycache__/inference_handler.cpython-38.pyc ADDED
Binary file (6.55 kB)

inference/__pycache__/ocr.cpython-38.pyc ADDED
Binary file (2.65 kB)

inference/__pycache__/utils.cpython-38.pyc ADDED
Binary file (2.4 kB)
inference/annotate_image.py ADDED
@@ -0,0 +1,50 @@
+import os
+from PIL import Image, ImageDraw, ImageFont
+from .utils import image_label_2_color
+
+
+def get_flattened_output(docs):
+    flattened_output = []
+    annotation_key = 'output'
+    for doc in docs:
+        flattened_output_item = {annotation_key: []}
+        doc_annotation = doc[annotation_key]
+        for i, span in enumerate(doc_annotation):
+            if len(span['words']) > 1:
+                for span_chunk in span['words']:
+                    flattened_output_item[annotation_key].append(
+                        {
+                            'label': span['label'],
+                            'text': span_chunk['text'],
+                            'words': [span_chunk]
+                        }
+                    )
+            else:
+                flattened_output_item[annotation_key].append(span)
+        flattened_output.append(flattened_output_item)
+    return flattened_output
+
+
+def annotate_image(image_path, annotation_object):
+    img = None
+    image = Image.open(image_path).convert('RGBA')
+    tmp = image.copy()
+    label2color = image_label_2_color(annotation_object)
+    overlay = Image.new('RGBA', tmp.size, (0, 0, 0) + (0,))
+    draw = ImageDraw.Draw(overlay)
+    font = ImageFont.load_default()
+
+    predictions = [span['label'] for span in annotation_object['output']]
+    boxes = [span['words'][0]['box'] for span in annotation_object['output']]
+    for prediction, box in zip(predictions, boxes):
+        draw.rectangle(box, outline=label2color[prediction],
+                       width=3, fill=label2color[prediction] + (int(255 * 0.33),))
+        draw.text((box[0] + 10, box[1] - 10), text=prediction,
+                  fill=label2color[prediction], font=font)
+
+    img = Image.alpha_composite(tmp, overlay)
+    img = img.convert("RGB")
+
+    image_name = os.path.basename(image_path)
+    image_name = image_name[:image_name.find('.')]
+    img.save(f'/content/{image_name}_inference.jpg')
inference/inference_handler.py ADDED
@@ -0,0 +1,196 @@
+from .utils import load_model, load_processor, normalize_box, compare_boxes, adjacent
+from .annotate_image import get_flattened_output, annotate_image
+from PIL import Image, ImageDraw, ImageFont
+import logging
+import torch
+import json
+
+
+
+logger = logging.getLogger(__name__)
+
+class ModelHandler(object):
+    """
+    A base Model handler implementation.
+    """
+
+    def __init__(self):
+        self.model = None
+        self.model_dir = None
+        self.device = 'cpu'
+        self.error = None
+        # self._context = None
+        # self._batch_size = 0
+        self.initialized = False
+        self._raw_input_data = None
+        self._processed_data = None
+        self._images_size = None
+
+    def initialize(self, context):
+        """
+        Initialize model. This will be called during model loading time.
+        :param context: Initial context contains model server system properties.
+        :return:
+        """
+        logger.info("Loading transformer model")
+
+        self._context = context
+        properties = self._context
+        # self._batch_size = properties["batch_size"] or 1
+        self.model_dir = properties.get("model_dir")
+        self.model = self.load(self.model_dir)
+        self.initialized = True
+
+    def preprocess(self, batch):
+        """
+        Transform raw input into model input data.
+        :param batch: list of raw requests, should match batch size
+        :return: list of preprocessed model input data
+        """
+        # Take the input data and pre-process it to make it inference ready
+        # assert self._batch_size == len(batch), "Invalid input batch size: {}".format(len(batch))
+        inference_dict = batch
+        self._raw_input_data = inference_dict
+        processor = load_processor()
+        images = [Image.open(path).convert("RGB")
+                  for path in inference_dict['image_path']]
+        self._images_size = [img.size for img in images]
+        words = inference_dict['words']
+        boxes = [[normalize_box(box, images[i].size[0], images[i].size[1])
+                  for box in doc] for i, doc in enumerate(inference_dict['bboxes'])]
+        encoded_inputs = processor(
+            images, words, boxes=boxes, return_tensors="pt", padding="max_length", truncation=True)
+        self._processed_data = encoded_inputs
+        return encoded_inputs
+
+    def load(self, model_dir):
+        """The load handler is responsible for loading the Hugging Face transformer model.
+        Returns:
+            hf_pipeline (Pipeline): A Hugging Face Transformers pipeline.
+        """
+        # TODO model dir should be microsoft/layoutlmv2-base-uncased
+        model = load_model(model_dir)
+        return model
+
+    def inference(self, model_input):
+        """
+        Internal inference methods.
+        :param model_input: transformed model input data
+        :return: list of inference output in NDArray
+        """
+        # TODO load the model state_dict before running the inference
+        # Do some inference call to engine here and return output
+        with torch.no_grad():
+            inference_outputs = self.model(**model_input)
+            predictions = inference_outputs.logits.argmax(-1).tolist()
+        results = []
+        for i in range(len(predictions)):
+            tmp = dict()
+            tmp[f'output_{i}'] = predictions[i]
+            results.append(tmp)
+
+        return [results]
+
+    def postprocess(self, inference_output):
+        docs = []
+        k = 0
+        for page, doc_words in enumerate(self._raw_input_data['words']):
+            doc_list = []
+            width, height = self._images_size[page]
+            for i, doc_word in enumerate(doc_words, start=0):
+                word_tagging = None
+                word_labels = []
+                word = dict()
+                word['id'] = k
+                k += 1
+                word['text'] = doc_word
+                word['pageNum'] = page + 1
+                word['box'] = self._raw_input_data['bboxes'][page][i]
+                _normalized_box = normalize_box(
+                    self._raw_input_data['bboxes'][page][i], width, height)
+                for j, box in enumerate(self._processed_data['bbox'].tolist()[page]):
+                    if compare_boxes(box, _normalized_box):
+                        if self.model.config.id2label[inference_output[0][page][f'output_{page}'][j]] != 'O':
+                            word_labels.append(
+                                self.model.config.id2label[inference_output[0][page][f'output_{page}'][j]][2:])
+                        else:
+                            word_labels.append('other')
+                if word_labels != []:
+                    word_tagging = word_labels[0] if word_labels[0] != 'other' else word_labels[-1]
+                else:
+                    word_tagging = 'other'
+                word['label'] = word_tagging
+                word['pageSize'] = {'width': width, 'height': height}
+                if word['label'] != 'other':
+                    doc_list.append(word)
+            spans = []
+            def adjacents(entity): return [
+                adj for adj in doc_list if adjacent(entity, adj)]
+            output_test_tmp = doc_list[:]
+            for entity in doc_list:
+                if adjacents(entity) == []:
+                    spans.append([entity])
+                    output_test_tmp.remove(entity)
+
+            while output_test_tmp != []:
+                span = [output_test_tmp[0]]
+                output_test_tmp = output_test_tmp[1:]
+                while output_test_tmp != [] and adjacent(span[-1], output_test_tmp[0]):
+                    span.append(output_test_tmp[0])
+                    output_test_tmp.remove(output_test_tmp[0])
+                spans.append(span)
+
+            output_spans = []
+            for span in spans:
+                if len(span) == 1:
+                    output_span = {"text": span[0]['text'],
+                                   "label": span[0]['label'],
+                                   "words": [{
+                                       'id': span[0]['id'],
+                                       'box': span[0]['box'],
+                                       'text': span[0]['text']
+                                   }],
+                                   }
+                else:
+                    output_span = {"text": ' '.join([entity['text'] for entity in span]),
+                                   "label": span[0]['label'],
+                                   "words": [{
+                                       'id': entity['id'],
+                                       'box': entity['box'],
+                                       'text': entity['text']
+                                   } for entity in span]
+
+                                   }
+                output_spans.append(output_span)
+            docs.append({f'output': output_spans})
+        return [json.dumps(docs, ensure_ascii=False)]
+
+    def handle(self, data, context):
+        """
+        Call preprocess, inference and post-process functions.
+        :param data: input data
+        :param context: mms context
+        """
+        model_input = self.preprocess(data)
+        model_out = self.inference(model_input)
+        inference_out = self.postprocess(model_out)[0]
+        with open('LayoutlMV3InferenceOutput.json', 'w') as inf_out:
+            inf_out.write(inference_out)
+        inference_out_list = json.loads(inference_out)
+        flattened_output_list = get_flattened_output(inference_out_list)
+        for i, flattened_output in enumerate(flattened_output_list):
+            annotate_image(data['image_path'][i], flattened_output)
+
+
+
+_service = ModelHandler()
+
+
+def handle(data, context):
+    if not _service.initialized:
+        _service.initialize(context)
+
+    if data is None:
+        return None
+
+    return _service.handle(data, context)
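The module-level handle() pairs the batch dict built by inference/ocr.py with a context whose only required key is "model_dir" (read in ModelHandler.initialize via properties.get). A hedged sketch with hypothetical paths:

from inference.ocr import prepare_batch_for_inference
from inference.inference_handler import handle

# Hypothetical image; the batch carries "image_path", "words" and "bboxes".
batch = prepare_batch_for_inference(["samples/receipt.jpg"])
handle(batch, {"model_dir": "tmp/receipt.pth"})
# Side effects: writes LayoutlMV3InferenceOutput.json plus one annotated
# /content/<name>_inference.jpg per input image.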
inference/ocr.py ADDED
@@ -0,0 +1,60 @@
+import os
+import pandas as pd
+
+def is_image(img_path):
+    ext = os.path.splitext(img_path)[1]
+    result = ext == ".jpg" or ext == ".png"
+    if not result: print('NOT IMAGE: ', img_path)
+    return result
+
+def run_tesseract_on_image(image_path):  # -> tsv output path
+    print('--- run tesseract on ', image_path)
+    image_name = os.path.basename(image_path)
+    image_name = image_name[:image_name.find('.')]
+    error_code = os.system(f'''
+    tesseract "{image_path}" "/content/{image_name}" -l eng tsv
+    ''')
+    if not error_code:
+        return f"/content/{image_name}.tsv"
+    else:
+        raise ValueError('Tesseract OCR Error please verify image format PNG,JPG,JPEG')
+
+
+def clean_tesseract_output(tsv_output_path):
+    print('clean tesseract output for: ', tsv_output_path)
+    ocr_df = pd.read_csv(tsv_output_path, sep='\t')
+    ocr_df = ocr_df.dropna()
+    ocr_df = ocr_df.drop(ocr_df[ocr_df.text.str.strip() == ''].index)
+    text_output = ' '.join(ocr_df.text.tolist())
+    words = []
+    for index, row in ocr_df.iterrows():
+        word = {}
+        origin_box = [row['left'], row['top'], row['left'] +
+                      row['width'], row['top'] + row['height']]
+        word['word_text'] = row['text']
+        word['word_box'] = origin_box
+        words.append(word)
+    return words
+
+
+def prepare_batch_for_inference(image_paths):
+    # tesseract_outputs is a list of paths
+    inference_batch = dict()
+    tesseract_outputs = [run_tesseract_on_image(
+        image_path) for image_path in image_paths if (is_image(image_path))]
+
+    print('tesseract has run on all images...')
+    # clean_outputs is a list of lists
+    clean_outputs = [clean_tesseract_output(
+        tsv_path) for tsv_path in tesseract_outputs]
+    word_lists = [[word['word_text'] for word in clean_output]
+                  for clean_output in clean_outputs]
+    boxes_lists = [[word['word_box'] for word in clean_output]
+                   for clean_output in clean_outputs]
+    inference_batch = {
+        "image_path": image_paths,
+        "bboxes": boxes_lists,
+        "words": word_lists
+    }
+    print('inference_batch:', inference_batch)
+    return inference_batch
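For reference, clean_tesseract_output turns each TSV row into a dict pairing the OCR text with its pixel box; a sketch of one entry with hypothetical values:

# Boxes are [left, top, left + width, top + height] in image pixels.
word = {
    "word_text": "TOTAL",
    "word_box": [34, 510, 98, 532],
}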
inference/utils.py ADDED
@@ -0,0 +1,50 @@
+import numpy as np
+from transformers import AutoModelForTokenClassification, AutoProcessor
+
+def normalize_box(bbox, width, height):
+    return [
+        int(bbox[0] * (1000 / width)),
+        int(bbox[1] * (1000 / height)),
+        int(bbox[2] * (1000 / width)),
+        int(bbox[3] * (1000 / height)),
+    ]
+
+def compare_boxes(b1, b2):
+    b1 = np.array([c for c in b1])
+    b2 = np.array([c for c in b2])
+    equal = np.array_equal(b1, b2)
+    return equal
+
+def unnormalize_box(bbox, width, height):
+    return [
+        width * (bbox[0] / 1000),
+        height * (bbox[1] / 1000),
+        width * (bbox[2] / 1000),
+        height * (bbox[3] / 1000),
+    ]
+
+def adjacent(w1, w2):
+    if w1['label'] == w2['label'] and abs(w1['id'] - w2['id']) == 1:
+        return True
+    return False
+
+def random_color():
+    return np.random.randint(0, 255, 3)
+
+def image_label_2_color(annotation):
+    if 'output' in annotation.keys():
+        image_labels = set([span['label'] for span in annotation['output']])
+        label2color = {f'{label}': (random_color()[0], random_color()[
+            1], random_color()[2]) for label in image_labels}
+        return label2color
+    else:
+        raise ValueError('please use "output" as annotation key')
+
+def load_model(model_path):
+    model = AutoModelForTokenClassification.from_pretrained(model_path)
+    return model
+
+def load_processor():
+    processor = AutoProcessor.from_pretrained(
+        "microsoft/layoutlmv3-base", apply_ocr=False)
+    return processor
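A worked example of the 0-1000 coordinate scaling these helpers implement (LayoutLM-style processors expect boxes normalized to a 1000x1000 grid):

# On a 500x1000-pixel page, x coordinates scale by 1000/500 = 2 and y by 1.
assert normalize_box([50, 100, 200, 300], 500, 1000) == [100, 100, 400, 300]
# unnormalize_box maps back to pixels (as floats):
assert unnormalize_box([100, 100, 400, 300], 500, 1000) == [50.0, 100.0, 200.0, 300.0]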
requirements.txt CHANGED
@@ -3,39 +3,76 @@ aiosignal==1.3.1
 altair==4.2.0
 anyio==3.6.2
 async-timeout==4.0.2
+attrs==22.1.0
+beautifulsoup4==4.11.1
+certifi==2022.12.7
+charset-normalizer==2.1.1
 click==8.1.3
 contourpy==1.0.6
 cycler==0.11.0
+entrypoints==0.4
 fastapi==0.88.0
 ffmpy==0.3.0
+filelock==3.8.2
 fonttools==4.38.0
 frozenlist==1.3.3
 fsspec==2022.11.0
+gdown==4.6.0
 gradio==3.14.0
 h11==0.14.0
 httpcore==0.16.2
 httpx==0.23.1
+huggingface-hub==0.11.1
+idna==3.4
+importlib-resources==5.10.1
+Jinja2==3.1.2
+jsonschema==4.17.3
 kiwisolver==1.4.4
 linkify-it-py==1.0.3
 markdown-it-py==2.1.0
+MarkupSafe==2.1.1
 matplotlib==3.6.2
 mdit-py-plugins==0.3.3
 mdurl==0.1.2
 multidict==6.0.3
 numpy==1.23.5
+nvidia-cublas-cu11==11.10.3.66
+nvidia-cuda-nvrtc-cu11==11.7.99
+nvidia-cuda-runtime-cu11==11.7.99
+nvidia-cudnn-cu11==8.5.0.96
 orjson==3.8.3
+packaging==22.0
 pandas==1.5.2
 Pillow==9.3.0
+pkgutil_resolve_name==1.3.10
 pycryptodome==3.16.0
 pydantic==1.10.2
 pydub==0.25.1
+pyparsing==3.0.9
+pyrsistent==0.19.2
+PySocks==1.7.1
+pytesseract==0.3.10
+python-dateutil==2.8.2
+python-dotenv==0.21.0
 python-multipart==0.0.5
 pytz==2022.7
+PyYAML==6.0
+regex==2022.10.31
+requests==2.28.1
 rfc3986==1.5.0
+six==1.16.0
 sniffio==1.3.0
+soupsieve==2.3.2.post1
 starlette==0.22.0
+tokenizers==0.13.2
 toolz==0.12.0
+torch==1.13.1
+tqdm==4.64.1
+transformers @ git+https://github.com/huggingface/transformers.git@7032e0203262ebb2ebf55da8d2e01f873973e835
+typing_extensions==4.4.0
 uc-micro-py==1.0.1
+urllib3==1.26.13
 uvicorn==0.20.0
 websockets==10.4
 yarl==1.8.2
+zipp==3.11.0
util/__pycache__/file_helper.cpython-38.pyc ADDED
Binary file (521 Bytes)
util/file_helper.py ADDED
@@ -0,0 +1,15 @@
+import gdown
+import os
+
+
+def download_gdrive(id, dir=".", filename=None):
+    print('download...')
+    tmp_filename = gdown.download(id=id, quiet=True)
+    if filename is None:
+        filename = tmp_filename
+
+    file_path = f'{dir}/{filename}'
+
+    if os.path.isdir(dir) == False: os.mkdir(dir)
+    os.replace(tmp_filename, file_path)
+    return file_path
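A minimal usage sketch for download_gdrive, with a hypothetical Drive file id: gdown downloads into the working directory, then the file is moved under dir/filename (mirroring the get_model() call in app.py):

from util import file_helper

# Hypothetical id; app.py passes os.getenv('MODEL_ID') here.
path = file_helper.download_gdrive("1AbCdEfGhIjK", dir="tmp", filename="receipt.pth")
print(path)  # tmp/receipt.pth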