# GLM-4-DOC / app.py
import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
import pymupdf
import docx
from pptx import Presentation
MODEL_LIST = ["THUDM/glm-4v-9b"]
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = os.environ.get("MODEL_ID", MODEL_LIST[0])  # fall back to the default model when MODEL_ID is unset
MODEL_NAME = MODEL_ID.split("/")[-1]
TITLE = "<h1>Multimodal Model for Complex Doc Extraction</h1>"
DESCRIPTION = f"""
<center>
<p>😊 A demo for complex document extraction via GLM-4.
<br>
🚀 MODEL NOW: <a href="https://hf.co/{MODEL_ID}">{MODEL_NAME}</a>
<br>
✨ Important: Do not upload any sensitive documents.
<br>
🙇‍♂️ This Space may be rebuilt from time to time.</p>
</center>"""
CSS = """
h1 {
    text-align: center;
    display: block;
}
"""
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True
).to(0)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model.eval()
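# Note: .to(0) places the model on the first CUDA device; the @spaces.GPU() decorator on
# stream_chat below requests GPU time for each generation call when running on ZeroGPU hardware.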
def extract_text(path):
    # read plain-text files; a context manager avoids leaking the file handle
    with open(path, 'r', encoding='utf-8', errors='ignore') as f:
        return f.read()
def extract_pdf(path):
    doc = pymupdf.open(path)
    text = ""
    for page in doc:
        text += page.get_text()
    return text
def extract_docx(path):
    doc = docx.Document(path)
    data = []
    for paragraph in doc.paragraphs:
        data.append(paragraph.text)
    content = '\n\n'.join(data)
    return content
def extract_pptx(path):
    prs = Presentation(path)
    text = ""
    for slide in prs.slides:
        for shape in slide.shapes:
            if hasattr(shape, "text"):
                text += shape.text + "\n"
    return text
def mode_load(path):
    choice = ""
    file_type = path.split(".")[-1]
    print(file_type)
    if file_type in ["pdf", "txt", "py", "docx", "pptx", "json", "cpp", "md"]:
        if file_type == "pdf":
            content = extract_pdf(path)
        elif file_type == "docx":
            content = extract_docx(path)
        elif file_type == "pptx":
            content = extract_pptx(path)
        else:
            content = extract_text(path)
        choice = "doc"
        print(content[:100])
        return choice, content[:5000]
    elif file_type in ["png", "jpg", "jpeg", "bmp", "tiff", "webp"]:
        content = Image.open(path).convert('RGB')
        choice = "image"
        return choice, content
    else:
        raise gr.Error("Oops, unsupported files.")
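# Example (illustrative file names): mode_load("form.pdf") returns ("doc", <first 5000 chars of text>),
# while mode_load("scan.png") returns ("image", <RGB PIL.Image>); unsupported extensions raise gr.Error.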
@spaces.GPU()
def stream_chat(message, history: list, temperature: float, max_length: int, top_p: float, top_k: int, penalty: float):
    print(f'message is - {message}')
    print(f'history is - {history}')
    conversation = []
    prompt_files = []
    if message["files"]:
        # a file was attached to the current message
        choice, contents = mode_load(message["files"][-1])
        if choice == "image":
            conversation.append({"role": "user", "image": contents, "content": message['text']})
        elif choice == "doc":
            format_msg = contents + "\n\n\n" + "{} files uploaded.\n".format(len(message["files"])) + message['text']
            conversation.append({"role": "user", "content": format_msg})
    else:
        if len(history) == 0:
            # raise gr.Error("Please upload an image first.")
            contents = None
            conversation.append({"role": "user", "content": message['text']})
        else:
            # image = Image.open(history[0][0][0])
            # rebuild the conversation from history; file-only turns have answer == None
            for prompt, answer in history:
                if answer is None:
                    prompt_files.append(prompt[0])
                    conversation.extend([{"role": "user", "content": ""}, {"role": "assistant", "content": ""}])
                else:
                    conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
            if prompt_files:
                # reuse the most recently uploaded file from history
                choice, contents = mode_load(prompt_files[-1])
            else:
                choice, contents = "", None
            if choice == "image":
                conversation.append({"role": "user", "image": contents, "content": message['text']})
            elif choice == "doc":
                format_msg = contents + "\n\n\n" + "{} files uploaded.\n".format(len(prompt_files)) + message['text']
                conversation.append({"role": "user", "content": format_msg})
            else:
                conversation.append({"role": "user", "content": message['text']})
    print(f"Conversation is -\n{conversation}")
    model_inputs = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        max_length=max_length,
        streamer=streamer,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=penalty,
        eos_token_id=[151329, 151336, 151338],  # GLM-4 stop-token ids
    )
    gen_kwargs = {**model_inputs, **generate_kwargs}
    with torch.no_grad():
        # generation runs in a background thread; tokens stream back through the streamer
        thread = Thread(target=model.generate, kwargs=gen_kwargs)
        thread.start()
        buffer = ""
        for new_text in streamer:
            buffer += new_text
            yield buffer
chatbot = gr.Chatbot()
chat_input = gr.MultimodalTextbox(
    interactive=True,
    placeholder="Enter a message or upload one file at a time...",
    show_label=False,
)
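# With multimodal=True, gr.ChatInterface passes each submission to stream_chat as a dict of the
# form {"text": "...", "files": ["/path/to/upload", ...]}, which is why the handler above reads
# message["text"] and message["files"].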
prompt_extraction = '''
Please extract the key information from the given document, including:
1 Date (DD/MM/YY)
2 Name(s) of Account Holder(s)
3 ID Document No.
4 ID Document Type
5 (Checked?) Are you the existing Personal e-Banking user? (expected result: yes/no)
6 (Checked?) Delete all Registered Bill Account Numbers for bill payment via any accounts under the above-mentioned ID Document or Business Registration No... (expected result: yes/no)
7 (Checked?) Add the following Bill Account Numbers for bill payment via any accounts under the above-mentioned ID Document or Business Registration No... (expected result: yes/no)
8 Signature(s) of Account Holder(s)
9 Account No.
Please note:
Return the result in plain JSON format (no nested structure), using the 9 keys listed above; do not omit any item from the list even if its value is None.
Please also note:
Your accuracy at detecting checkboxes is relatively low, and unchecked boxes are often mistakenly identified as checked/true (resulting in many false positives).
Therefore, for all fields marked "(Checked?)", please be extra cautious and carefully examine the image before extracting the information; if you are unsure, please default to false.
'''
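# For reference, the flat JSON requested by prompt_extraction looks roughly like the sketch below
# (illustrative key names paraphrased from the list above; all values are placeholders):
# {
#   "Date": "DD/MM/YY",
#   "Name(s) of Account Holder(s)": null,
#   "ID Document No.": null,
#   "ID Document Type": null,
#   "Existing Personal e-Banking user (checked?)": "no",
#   "Delete all Registered Bill Account Numbers (checked?)": "no",
#   "Add the following Bill Account Numbers (checked?)": "no",
#   "Signature(s) of Account Holder(s)": null,
#   "Account No.": null
# }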
prompt_signature = '''
I need your help to compare and score the similarity of two signatures. You should evaluate the signatures based on several dimensions and calculate a final similarity score. Here are the detailed instructions and dimensions for the comparison:
1. **General Shape and Flow**:
- Compare the overall shape and flow of the two signatures.
- Score from 0 to 10, where 0 means completely different and 10 means identical.
2. **Consistency of Loops and Strokes**:
- Evaluate the presence and consistency of loops and strokes in the signatures.
- Score from 0 to 10 based on the similarity of these features.
3. **Signature Characteristics**:
- Compare specific characteristics such as dots, dashes, and unique flourishes.
- Score from 0 to 10 based on the presence and similarity of these unique features.
4. **Stroke Pressure and Line Thickness**:
- Analyze the pressure and thickness of the lines in the signatures.
- Score from 0 to 10 based on how similar the pressure and thickness are between the two signatures.
5. **Angle and Slope**:
- Evaluate the angle and slope of the characters in the signatures.
- Score from 0 to 10 based on how similar the angles and slopes are.
6. **Spacing and Proportions**:
- Compare the spacing between characters and the proportions of the signatures.
- Score from 0 to 10 based on the similarity of spacing and proportions.
After scoring each dimension, calculate the final similarity score by averaging the scores from all dimensions. The final similarity score should be a value between 0 and 10, where 0 indicates no similarity and 10 indicates identical signatures.
Here is an example output format for the comparison:
```
General Shape and Flow: 8
Consistency of Loops and Strokes: 7
Signature Characteristics: 6
Stroke Pressure and Line Thickness: 5
Angle and Slope: 8
Spacing and Proportions: 7
Final Similarity Score: 6.83
```
Please help with this comparison and scoring for the two provided signatures.
Important: Only score based on visual intuition; do not run code or provide coding solutions.
'''
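# A minimal sketch (not wired into the app) of the scoring rule described in prompt_signature:
# the final similarity score is simply the mean of the six per-dimension scores, rounded to two
# decimals. The helper name is hypothetical; the sample values mirror the prompt's example output.
def _average_signature_scores(dimension_scores):
    # e.g. [8, 7, 6, 5, 8, 7] -> round(41 / 6, 2) == 6.83
    return round(sum(dimension_scores) / len(dimension_scores), 2)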
EXAMPLES = [
    [{"text": prompt_extraction, "files": ["./IMG_2700.png"]}],
    [{"text": prompt_signature, "files": ["./2.jpg"]}],
    [{"text": "Is it real?", "files": ["./spacecat.png"]}],
]
with gr.Blocks(css=CSS, theme="soft", fill_height=True) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.ChatInterface(
        fn=stream_chat,
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=1024,
                maximum=8192,
                step=1,
                value=4096,
                label="Max Length",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=1.0,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=10,
                label="top_k",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=2.0,
                step=0.1,
                value=1.0,
                label="Repetition penalty",
                render=False,
            ),
        ],
    )
    gr.Examples(EXAMPLES, [chat_input])
if __name__ == "__main__":
    demo.queue(api_open=False).launch(show_api=False, share=False)