Harsimran19 committed on
Commit
4331eba
1 Parent(s): 6d92fcd

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +40 -0
  2. models/document_model/config.json +63 -0
  3. preprocess.py +83 -0
  4. requirements.txt +82 -0
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytesseract
2
+ import torch
3
+ import gradio as gr
4
+ from transformers import LayoutLMForSequenceClassification
5
+ from preprocess import apply_ocr,encode_example
6
+
7
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _configure_tesseract(windows_path=r"C:\\Program Files\\Tesseract-OCR\\tesseract.exe"):
    """Point pytesseract at the Windows Tesseract install only if it exists.

    Assigning the path unconditionally breaks OCR on any host where that file
    is absent. When the file is missing, pytesseract's own setting is left
    untouched so the executable can be resolved by its default mechanism.

    Args:
        windows_path: Candidate location of ``tesseract.exe``. The default is
            kept verbatim from the original code (in a raw string the doubled
            backslashes are literal characters).
    """
    import os  # function-scope import: keeps this block self-contained

    if os.path.isfile(windows_path):
        pytesseract.pytesseract.tesseract_cmd = windows_path


_configure_tesseract()

# Fine-tuned classifier weights are loaded from the local model directory.
model = LayoutLMForSequenceClassification.from_pretrained("models/document_model")
model.to(device)

# Human-readable class names, indexed by the model's output position.
# NOTE(review): the order is assumed to match the label ids used at training
# time -- confirm against the training label mapping.
classes = [
    'questionnaire', 'memo', 'budget', 'file_folder', 'specification', 'invoice', 'resume',
    'advertisement', 'news_article', 'email', 'scientific_publication', 'presentation',
    'letter', 'form', 'handwritten', 'scientific_report',
]
14
+
15
+
16
def predict(image):
    """Classify a document image and return the predicted class name.

    The image is OCR'd, the words and boxes are encoded into fixed-length
    model inputs, and the class with the highest score is returned.

    Args:
        image: The image handed over by the Gradio interface; it is passed
            straight to ``apply_ocr``.

    Returns:
        The entry of ``classes`` whose index has the highest logit.
    """
    example = apply_ocr(image)
    encoded_example = encode_example(example)

    # Build a batch of one and place every tensor on the model's device;
    # leaving them on the CPU fails as soon as the model lives on the GPU.
    input_keys = ('input_ids', 'bbox', 'attention_mask', 'token_type_ids')
    model_inputs = {
        key: torch.tensor(encoded_example[key]).unsqueeze(0).to(device)
        for key in input_keys
    }

    model.eval()
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**model_inputs)

    # Softmax is monotonic, so the argmax of the logits is the argmax of the
    # probabilities the original computed.
    predicted_index = int(outputs.logits.argmax(dim=1).item())
    return classes[predicted_index]
29
+
30
+
31
+
32
title = "Document Image Classification"

# gr.Image / gr.Textbox replace the deprecated gr.inputs.* / gr.outputs.*
# namespaces; the image is still delivered to predict() as a PIL image.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Predicted Class"),
    title=title,
)

# Only start the server when executed as a script, so importing this module
# (for example to reuse predict) does not launch a web app as a side effect.
if __name__ == "__main__":
    demo.launch(share=True)
models/document_model/config.json ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "microsoft/layoutlm-base-uncased",
3
+ "architectures": [
4
+ "LayoutLMForSequenceClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "hidden_act": "gelu",
8
+ "hidden_dropout_prob": 0.1,
9
+ "hidden_size": 768,
10
+ "id2label": {
11
+ "0": "LABEL_0",
12
+ "1": "LABEL_1",
13
+ "2": "LABEL_2",
14
+ "3": "LABEL_3",
15
+ "4": "LABEL_4",
16
+ "5": "LABEL_5",
17
+ "6": "LABEL_6",
18
+ "7": "LABEL_7",
19
+ "8": "LABEL_8",
20
+ "9": "LABEL_9",
21
+ "10": "LABEL_10",
22
+ "11": "LABEL_11",
23
+ "12": "LABEL_12",
24
+ "13": "LABEL_13",
25
+ "14": "LABEL_14",
26
+ "15": "LABEL_15"
27
+ },
28
+ "initializer_range": 0.02,
29
+ "intermediate_size": 3072,
30
+ "label2id": {
31
+ "LABEL_0": 0,
32
+ "LABEL_1": 1,
33
+ "LABEL_10": 10,
34
+ "LABEL_11": 11,
35
+ "LABEL_12": 12,
36
+ "LABEL_13": 13,
37
+ "LABEL_14": 14,
38
+ "LABEL_15": 15,
39
+ "LABEL_2": 2,
40
+ "LABEL_3": 3,
41
+ "LABEL_4": 4,
42
+ "LABEL_5": 5,
43
+ "LABEL_6": 6,
44
+ "LABEL_7": 7,
45
+ "LABEL_8": 8,
46
+ "LABEL_9": 9
47
+ },
48
+ "layer_norm_eps": 1e-12,
49
+ "max_2d_position_embeddings": 1024,
50
+ "max_position_embeddings": 512,
51
+ "model_type": "layoutlm",
52
+ "num_attention_heads": 12,
53
+ "num_hidden_layers": 12,
54
+ "output_past": true,
55
+ "pad_token_id": 0,
56
+ "position_embedding_type": "absolute",
57
+ "problem_type": "single_label_classification",
58
+ "torch_dtype": "float32",
59
+ "transformers_version": "4.30.2",
60
+ "type_vocab_size": 2,
61
+ "use_cache": true,
62
+ "vocab_size": 30522
63
+ }
preprocess.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytesseract
2
+ from PIL import Image
3
+ import numpy as np
4
+ from transformers import LayoutLMTokenizer
5
+
6
+
7
def _configure_tesseract(windows_path=r"C:\\Program Files\\Tesseract-OCR\\tesseract.exe"):
    """Point pytesseract at the Windows Tesseract install only if it exists.

    Assigning the path unconditionally breaks OCR on any host where that file
    is absent. When the file is missing, pytesseract's own setting is left
    untouched so the executable can be resolved by its default mechanism.

    Args:
        windows_path: Candidate location of ``tesseract.exe``. The default is
            kept verbatim from the original code (in a raw string the doubled
            backslashes are literal characters).
    """
    import os  # function-scope import: keeps this block self-contained

    if os.path.isfile(windows_path):
        pytesseract.pytesseract.tesseract_cmd = windows_path


_configure_tesseract()

# Tokenizer of the base checkpoint; used below to split OCR words into tokens.
tokenizer = LayoutLMTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
9
def normalize_box(box, width, height):
    """Rescale a pixel-space box to a 0-1000 grid relative to the image size.

    Args:
        box: Sequence whose first four items are (x0, y0, x1, y1) in pixels.
        width: Image width; divides the x coordinates.
        height: Image height; divides the y coordinates.

    Returns:
        A list of four ints; each scaled value is truncated by ``int()``.
    """
    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
    # Pair every coordinate with the image extent along its own axis.
    pairs = ((x0, width), (y0, height), (x1, width), (y1, height))
    scaled = []
    for coord, extent in pairs:
        scaled.append(int(1000 * (coord / extent)))
    return scaled
16
+
17
def apply_ocr(image):
    """Run Tesseract on an image and collect its words with normalized boxes.

    Args:
        image: Image object exposing ``.size`` as (width, height); it is
            handed to ``pytesseract.image_to_data`` unchanged.

    Returns:
        A dict with 'words' (list of str) and 'bbox' (one box per word, as
        produced by ``normalize_box``).
    """
    img_width, img_height = image.size

    # OCR result as a table, one row per detected element.
    frame = pytesseract.image_to_data(image, output_type='data.frame')
    float_columns = frame.select_dtypes('float').columns
    # Drop rows with missing values, then store the float columns as ints.
    frame = frame.dropna().reset_index(drop=True)
    frame[float_columns] = frame[float_columns].round(0).astype(int)
    # Blank / whitespace-only cells become NaN so the second dropna removes them.
    frame = frame.replace(r'^\s*$', np.nan, regex=True)
    frame = frame.dropna().reset_index(drop=True)

    words = [str(token) for token in frame.text]

    # Each row is (left, top, width, height); convert it to corner form
    # (left, top, left + width, top + height) and normalize in one pass.
    geometry = frame[['left', 'top', 'width', 'height']]
    boxes = [
        normalize_box([left, top, left + box_w, top + box_h], img_width, img_height)
        for left, top, box_w, box_h in geometry.itertuples(index=False, name=None)
    ]

    assert len(words) == len(boxes)
    return {'words': words, 'bbox': boxes}
52
def encode_example(example, max_seq_length=512, pad_token_box=None):
    """Turn OCR words and boxes into fixed-length LayoutLM model inputs.

    Args:
        example: Mapping with 'words' (list of str) and 'bbox' (one normalized
            box per word), as returned by ``apply_ocr``.
        max_seq_length: Length of every returned sequence, counting the two
            special tokens and the padding.
        pad_token_box: Box assigned to padding positions. Defaults to
            ``[0, 0, 0, 0]``.

    Returns:
        The tokenizer's encoding with an added 'bbox' entry holding one box
        per position; every sequence has length ``max_seq_length``.

    Raises:
        ValueError: If the number of words and boxes differ.
    """
    # None sentinel instead of a mutable list default shared across calls.
    if pad_token_box is None:
        pad_token_box = [0, 0, 0, 0]

    words = example['words']
    normalized_word_boxes = example['bbox']
    if len(words) != len(normalized_word_boxes):
        raise ValueError(
            f"got {len(words)} words but {len(normalized_word_boxes)} boxes"
        )

    # A word may split into several tokens; each token inherits the word's box.
    token_boxes = []
    for word, box in zip(words, normalized_word_boxes):
        token_boxes.extend([box] * len(tokenizer.tokenize(word)))

    # Leave room for the two special tokens, then add their boxes (cls + sep).
    special_tokens_count = 2
    token_boxes = token_boxes[: max_seq_length - special_tokens_count]
    token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]

    # Pass max_length explicitly: without it the tokenizer pads and truncates
    # to its own maximum and silently ignores max_seq_length.
    encoding = tokenizer(
        ' '.join(words),
        padding='max_length',
        truncation=True,
        max_length=max_seq_length,
    )

    # The attention mask marks the non-padding positions, so its sum is the
    # unpadded length -- no second tokenizer pass is needed.
    padding_length = max_seq_length - sum(encoding['attention_mask'])
    token_boxes += [list(pad_token_box) for _ in range(padding_length)]
    encoding['bbox'] = token_boxes

    # Internal sanity checks: every model input must have the same length.
    assert len(encoding['input_ids']) == max_seq_length
    assert len(encoding['attention_mask']) == max_seq_length
    assert len(encoding['token_type_ids']) == max_seq_length
    assert len(encoding['bbox']) == max_seq_length

    return encoding
requirements.txt ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.1.0
2
+ aiohttp==3.8.4
3
+ aiosignal==1.3.1
4
+ altair==5.0.1
5
+ annotated-types==0.5.0
6
+ anyio==3.7.1
7
+ async-timeout==4.0.2
8
+ attrs==23.1.0
9
+ certifi==2023.5.7
10
+ charset-normalizer==3.2.0
11
+ click==8.1.6
12
+ colorama==0.4.6
13
+ contourpy==1.1.0
14
+ cycler==0.11.0
15
+ datasets==2.13.1
16
+ dill==0.3.6
17
+ exceptiongroup==1.1.2
18
+ fastapi==0.100.0
19
+ ffmpy==0.3.1
20
+ filelock==3.12.2
21
+ fonttools==4.41.0
22
+ frozenlist==1.4.0
23
+ fsspec==2023.6.0
24
+ gradio==3.37.0
25
+ gradio_client==0.2.10
26
+ h11==0.14.0
27
+ httpcore==0.17.3
28
+ httpx==0.24.1
29
+ huggingface-hub==0.16.4
30
+ idna==3.4
31
+ Jinja2==3.1.2
32
+ jsonschema==4.18.4
33
+ jsonschema-specifications==2023.7.1
34
+ kiwisolver==1.4.4
35
+ linkify-it-py==2.0.2
36
+ markdown-it-py==2.2.0
37
+ MarkupSafe==2.1.3
38
+ matplotlib==3.7.2
39
+ mdit-py-plugins==0.3.3
40
+ mdurl==0.1.2
41
+ mpmath==1.3.0
42
+ multidict==6.0.4
43
+ multiprocess==0.70.14
44
+ networkx==3.1
45
+ numpy==1.25.1
46
+ orjson==3.9.2
47
+ packaging==23.1
48
+ pandas==2.0.3
49
+ Pillow==10.0.0
50
+ pyarrow==12.0.1
51
+ pydantic==2.0.3
52
+ pydantic_core==2.3.0
53
+ pydub==0.25.1
54
+ pyparsing==3.0.9
55
+ pytesseract==0.3.10
56
+ python-dateutil==2.8.2
57
+ python-multipart==0.0.6
58
+ pytz==2023.3
59
+ PyYAML==6.0.1
60
+ referencing==0.30.0
61
+ regex==2023.6.3
62
+ requests==2.31.0
63
+ rpds-py==0.9.2
64
+ safetensors==0.3.1
65
+ semantic-version==2.10.0
66
+ six==1.16.0
67
+ sniffio==1.3.0
68
+ starlette==0.27.0
69
+ sympy==1.12
70
+ tokenizers==0.13.3
71
+ toolz==0.12.0
72
+ torch==2.0.1
73
+ tqdm==4.65.0
74
+ transformers==4.31.0
75
+ typing_extensions==4.7.1
76
+ tzdata==2023.3
77
+ uc-micro-py==1.0.2
78
+ urllib3==2.0.3
79
+ uvicorn==0.23.1
80
+ websockets==11.0.3
81
+ xxhash==3.2.0
82
+ yarl==1.9.2