File size: 7,477 Bytes
94c5764
bc391b2
94c5764
 
 
bc391b2
 
94c5764
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc391b2
 
 
 
 
 
 
 
177bac3
bc391b2
 
 
94c5764
 
 
 
bc391b2
 
 
 
 
177bac3
 
bc391b2
94c5764
 
 
bc391b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94c5764
bc391b2
 
 
 
94c5764
bc391b2
 
94c5764
bc391b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177bac3
bc391b2
 
 
 
 
 
 
177bac3
bc391b2
 
 
 
 
 
 
 
 
 
 
 
 
177bac3
 
 
 
bc391b2
 
 
 
94c5764
 
 
 
 
 
bc391b2
 
 
 
 
 
 
 
94c5764
bc391b2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import time
import os
from os import getcwd, path
import importlib.metadata
from dotenv import load_dotenv


def check_additional_requirements():
    """Install runtime dependencies that are resolved at start-up instead of via requirements.

    - installs ``detectron2`` from its GitHub repository if it cannot be found,
    - ensures ``gradio`` is installed in exactly version 3.4.1 (reinstalling it
      when a different version is present),
    - runs the shell command stored in the ``DD_ADDONS`` environment variable.

    Install failures are not raised here (the pip exit code is ignored, as with
    the previous ``os.system`` calls); they surface later when the package is imported.

    :raises KeyError: if the ``DD_ADDONS`` environment variable is not set.
    """
    # ``importlib.util`` must be imported explicitly: ``import importlib.metadata``
    # at file level does not guarantee that this submodule is loaded.
    import importlib.util
    import subprocess
    import sys

    def _pip(*args):
        # Run pip through the interpreter executing this script so packages are
        # installed into *this* environment, not the one of whatever ``pip`` is on PATH.
        subprocess.run([sys.executable, "-m", "pip", *args], check=False)

    required_gradio = "3.4.1"

    if importlib.util.find_spec("detectron2") is None:
        _pip("install", "detectron2@git+https://github.com/facebookresearch/detectron2.git")

    if importlib.util.find_spec("gradio") is None:
        _pip("install", f"gradio=={required_gradio}")
    elif importlib.metadata.version("gradio") != required_gradio:
        _pip("uninstall", "-y", "gradio")
        _pip("install", f"gradio=={required_gradio}")

    # DD_ADDONS is a shell command taken from the deployment environment (it may be
    # defined in .env, loaded by load_dotenv()); presumably it installs the
    # dd_addons package imported below - confirm. It is trusted configuration and
    # is deliberately executed through the shell as-is.
    os.system(os.environ["DD_ADDONS"])

    # Packages installed after interpreter start-up may not be visible to the
    # import system's cached finders; refresh them so the later imports succeed.
    importlib.invalidate_caches()


# Load variables from a local .env file first: check_additional_requirements()
# reads the DD_ADDONS environment variable, which may be defined there.
load_dotenv()
# Install/pin the extra dependencies BEFORE the imports below, so that gradio
# (pinned to 3.4.1) and whatever the DD_ADDONS command installs (presumably the
# dd_addons package - confirm) are available when they are imported.
check_additional_requirements()


import deepdoctection as dd
from deepdoctection.dataflow.serialize import DataFromList
from deepdoctection.utils.settings import get_type

from dd_addons.analyzer.loader import get_loader
from dd_addons.extern.guidance import TOKEN_DEFAULT_INSTRUCTION
from dd_addons.utils.settings import register_llm_token_tag, register_string_categories_from_list
from dd_addons.extern.openai import OpenAiLmmTokenClassifier, is_api_key_valid

import gradio as gr

# Register "raw_json_output" as an additional attribute name on dd.Page;
# process_analyzer() reads ``dp.raw_json_output`` from every analyzed page.
dd.Page.add_attribute_name("raw_json_output")
# Build the analysis pipeline once at import time; every request reuses this
# module-level object. Config overrides: disable Tesseract OCR, enable Textract
# OCR, and set WORD_MATCHING.MAX_PARENT_ONLY=True (exact semantics are defined by
# the deepdoctection config - confirm). reset_config_file=True presumably
# regenerates the config file before applying the overrides - verify in get_loader.
analyzer = get_loader(reset_config_file=True, config_overwrite=["OCR.USE_TESSERACT=False",
                                                                "OCR.USE_TEXTRACT=True",
                                                                "WORD_MATCHING.MAX_PARENT_ONLY=True"])

# Top-level Gradio Blocks app; extra styling is taken from scrollbar.css.
demo = gr.Blocks(css="scrollbar.css")


def _error_response(message):
    """Return an empty result that matches the four Gradio outputs (gallery, JSON, raw JSON, message)."""
    return [], {}, {}, message


def process_analyzer(openai_api_key, categories_str, instruction_str, img, pdf, max_datapoints):
    """Run the document pipeline with a ChatGPT-based token classifier.

    Callback of the "Run model" button. Validates the user input, registers the
    requested entity names as token classes, plugs a new OpenAI token classifier
    into the shared ``analyzer`` and analyzes the uploaded image (preferred) or PDF.

    :param openai_api_key: OpenAI API key entered by the user.
    :param categories_str: comma-separated entity names, e.g. ``"name, date"``;
        surrounding whitespace and empty entries are ignored.
    :param instruction_str: optional prompt; an empty value lets the classifier
        use its own default instruction.
    :param img: uploaded image as numpy array, or ``None``.
    :param pdf: uploaded PDF file object (its ``name`` is the file path), or ``None``.
    :param max_datapoints: maximum number of PDF pages to analyze.
    :return: tuple ``(images, json_out, json_out_raw, message)``: page
        visualizations, ``page_<idx>`` -> extracted tokens, ``page_<idx>`` -> raw
        LLM output, and a status message. On invalid input the first three are
        empty and ``message`` explains the problem.
    """
    if not is_api_key_valid(openai_api_key):
        return _error_response("You have entered no or an invalid api key. Please enter a valid api key")

    # Normalise the entity list so "a, b," yields ["a", "b"]; also tolerates None.
    categories_list = [entry.strip() for entry in (categories_str or "").split(",") if entry.strip()]
    if not categories_list:
        return _error_response("You did not enter any entities. Please enter at least one category.")

    # Check for a document before touching the shared pipeline.
    if img is None and not pdf:
        return _error_response("Please upload an image or a PDF document.")

    register_string_categories_from_list(categories_list, "custom_token_classes")
    custom_token_class = dd.object_types_registry.get("custom_token_classes")
    register_llm_token_tag(list(custom_token_class))
    # Category ids are 1-based strings.
    categories = {str(idx): get_type(val) for idx, val in enumerate(categories_list, start=1)}

    gpt_token_classifier = OpenAiLmmTokenClassifier(
        model_name="gpt-3.5-turbo",
        categories=categories,
        api_key=openai_api_key,
        instruction=instruction_str or None,
    )
    # NOTE(review): index 8 is assumed to be the LLM token-classification component
    # of the pipeline built by get_loader(); a different pipeline layout would make
    # this target the wrong component. The analyzer is module-global, so concurrent
    # requests would overwrite each other's classifier - confirm requests run serially.
    analyzer.pipe_component_list[8].language_model = gpt_token_classifier

    if img is not None:
        image = dd.Image(file_name=str(time.time()).replace(".", "") + ".png", location="")
        # Reverse the channel axis (presumably RGB from the upload widget -> BGR
        # as used by deepdoctection - confirm).
        image.image = img[:, :, ::-1]
        df = analyzer.analyze(dataset_dataflow=DataFromList(lst=[image]))
    else:
        df = analyzer.analyze(path=pdf.name, max_datapoints=max_datapoints)

    df.reset_state()

    pages = []
    json_out = {}
    json_out_raw = {}
    for idx, page in enumerate(df):
        pages.append(page)
        json_out[f"page_{idx}"] = page.get_token()
        json_out_raw[f"page_{idx}"] = page.raw_json_output

    visualizations = [
        page.viz(show_cells=False, show_layouts=False, show_tables=False, show_words=True,
                 show_token_class=True, ignore_default_token_class=True)
        for page in pages
    ]
    return visualizations, json_out, json_out_raw, "No error"


# Assemble the page layout inside the ``demo`` Blocks context: a header, an input
# section (document upload, API key, entities, optional prompt, page limit) and an
# output section (status message, JSON results, page visualizations).
with demo:
    # Header
    with gr.Box():
        gr.Markdown("<h1><center>Document AI GPT</center></h1>")
        gr.Markdown("<h2 ><center>Zero or few-shot Entity Extraction powered by ChatGPT and <strong>deep</strong>doctection </center></h2>"
                    "<center>This pipeline consists of a stack of models powered for layout analysis and table recognition "
                    "to prepare a prompt for ChatGPT. </center>"
                    "<center>Be aware! The Space is still very fragile.</center><br />")
    # Inputs
    with gr.Box():
        gr.Markdown("<h2><center>Upload a document and choose setting</center></h2>")
        with gr.Row():
            with gr.Column():
                with gr.Tab("Image upload"):
                    with gr.Column():
                        inputs = gr.Image(type='numpy', label="Original Image")
                with gr.Tab("PDF upload *"):
                    with gr.Column():
                        inputs_pdf = gr.File(label="PDF")
                    # process_analyzer() uses the image whenever one is present,
                    # so a cached image takes precedence over an uploaded PDF.
                    gr.Markdown("<sup>* If an image is cached in tab, remove it first</sup>")
                with gr.Box():
                    gr.Examples(
                        examples=[path.join(getcwd(), "sample_2.png")],
                        inputs=inputs)
                with gr.Box():
                    gr.Markdown("Enter your OpenAI API Key* ")
                    user_token = gr.Textbox(value='', placeholder="OpenAI API Key", type="password", show_label=False)

                    gr.Markdown("<sup>* Your API key will not be saved. However, it is always recommended to deactivate the "
                                "API key once it is entered into an unknown source</sup>")
            with gr.Column():
                with gr.Box():
                    gr.Markdown(
                        "Enter a list of comma separated entities. Use a snake case style. Avoid special characters. "
                        "Best way is to only use `a-z` and `_`")
                    categories = gr.Textbox(value='', placeholder="mitarbeiter_anzahl", show_label=False)
                with gr.Box():
                    gr.Markdown("Optional: Enter a prompt for additional guidance. Will use the placeholder as fallback")
                    instruction = gr.Textbox(value='', placeholder=TOKEN_DEFAULT_INSTRUCTION, show_label=False)
        with gr.Row():
            max_imgs = gr.Slider(1, 3, value=1, step=1, label="Number of pages in multi page PDF",
                                 info="Will stop after 3 pages")

        with gr.Row():
            btn = gr.Button("Run model", variant="primary")

    # Outputs
    with gr.Box():
        gr.Markdown("<h2><center>Outputs</center></h2>")
        with gr.Row():
            with gr.Column():
                with gr.Box():
                    gr.Markdown("<center><strong>Message</strong></center>")
                    msg = gr.Textbox(value='', placeholder="message", show_label=False)
            with gr.Column():
                with gr.Box():
                    gr.Markdown("<center><strong>JSON</strong></center>")
                    # Named json_output to avoid shadowing the stdlib ``json`` module.
                    json_output = gr.JSON()
                with gr.Box():
                    gr.Markdown("<center><strong>ChatGPT output. </strong> <br />"
                                "It is possible that ChatGPT answers in an unexpected way, "
                                "such that the answer cannot be properly processed. In this case you might get "
                                "an empty JSON but you can still see the raw output.</center>")
                    json_raw = gr.JSON()
            with gr.Column():
                with gr.Box():
                    gr.Markdown("<center><strong>Layout detection</strong></center>")
                    gallery = gr.Gallery(
                        label="Output images", show_label=False, elem_id="gallery"
                    ).style(grid=2)

    # ``inputs`` follow process_analyzer's parameter order; ``outputs`` must match
    # its 4-tuple return value (page images, JSON, raw ChatGPT JSON, message).
    btn.click(fn=process_analyzer, inputs=[user_token, categories, instruction, inputs, inputs_pdf, max_imgs],
              outputs=[gallery, json_output, json_raw, msg])

demo.launch()