# Shanshan Wang
# updated prompts in examples
# 2f6e568
import gradio as gr
from transformers import AutoModel, AutoTokenizer
import torch
import threading
import os
# Caches of loaded models and tokenizers, keyed by the display name used in
# model_paths, so a model that is already loaded is not loaded again.
model_cache = {}
tokenizer_cache = {}
# Held while the caches are read or written by the load and inference functions.
model_lock = threading.Lock()
from huggingface_hub import login
# Hugging Face access token read from the `hf_token` environment variable;
# None when the variable is not set.
hf_token = os.environ.get('hf_token', None)
# Define the models and their paths
# Maps the display name offered in the model dropdown to the repository id
# passed to `from_pretrained`. The inference functions dispatch on the
# substrings '2b' / '0.8b' of the lower-cased display name, so any name added
# here should be chosen with that in mind.
model_paths = {
"H2OVL-Mississippi-2B":"h2oai/h2ovl-mississippi-2b",
"H2OVL-Mississippi-0.8B":"h2oai/h2ovl-mississippi-800m",
# Add more models as needed
}
# Preset prompts offered as choices in the question dropdown. The first entry
# doubles as the default question when the "OCR" task type is selected.
example_prompts = [
"Read the text and provide word by word ocr for the document. <doc>",
"Read the text on the image",
"Extract the text from the image.",
"Extract the text from the image and fill the following json {'license_number':'',\n'full_name':'',\n'date_of_birth':'',\n'address':'',\n'issue_date':'',\n'expiration_date':'',\n}",
"Please extract the following fields, and return the result in JSON format: supplier_name, supplier_address, customer_name, customer_address, invoice_number, invoice_total_amount, invoice_tax_amount",
]
# Function to handle task type logic
def handle_task_type(task_type, model_name):
    """Return the ``max_new_tokens`` value to use for ``task_type``.

    The "OCR" task type gets 3072 tokens; every other task type gets the
    default of 1024. ``model_name`` is part of the signature but is not
    consulted.
    """
    return 3072 if task_type == "OCR" else 1024
# Function to handle task type logic and default question
def handle_task_type_and_prompt(task_type, model_name):
    """Return ``(max_new_tokens, default_question)`` for the chosen task type.

    The token budget comes from ``handle_task_type``. The default question is
    the first entry of ``example_prompts`` for the "OCR" task type and ``None``
    for every other task type.
    """
    token_budget = handle_task_type(task_type, model_name)
    if task_type == "OCR":
        prompt = example_prompts[0]
    else:
        prompt = None
    return token_budget, prompt
def update_task_type_on_model_change(model_name):
    """Return the default ``(task_type, max_new_tokens)`` for a model name.

    Names containing '2b' (case-insensitive) default to "Document extractor",
    names containing '0.8b' default to "OCR", and anything else defaults to
    "Chat". The token budget is looked up with ``handle_task_type``.
    """
    lowered = model_name.lower()
    if '2b' in lowered:
        default_task = "Document extractor"
    elif '0.8b' in lowered:
        default_task = "OCR"
    else:
        default_task = "Chat"
    return default_task, handle_task_type(default_task, model_name)
def load_model_and_set_image_function(model_name):
    """Make sure the model and tokenizer for ``model_name`` are in the caches.

    On a cache miss the model is loaded from its repository in bfloat16, put
    into eval mode and moved to CUDA, the slow tokenizer is loaded, and both
    are stored in the caches under ``model_name``. On a cache hit only a log
    line is printed. The check and the load both run while ``model_lock`` is
    held.

    Raises:
        KeyError: if ``model_name`` is not a key of ``model_paths``.

    Returns:
        ``model_name``, unchanged.
    """
    # Resolve the repository id first; an unknown name fails before locking.
    repo_id = model_paths[model_name]

    with model_lock:
        if model_name not in model_cache:
            print(f"Loading model {model_name}...")
            loaded_model = AutoModel.from_pretrained(
                repo_id,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                use_auth_token=hf_token,
            )
            loaded_model = loaded_model.eval().cuda()
            loaded_tokenizer = AutoTokenizer.from_pretrained(
                repo_id,
                trust_remote_code=True,
                use_fast=False,
                use_auth_token=hf_token,
            )
            # Publish both objects to the caches only after both loads succeed.
            model_cache[model_name] = loaded_model
            tokenizer_cache[model_name] = loaded_tokenizer
            print(f"Model {model_name} loaded successfully.")
        else:
            print(f"Model {model_name} is already loaded. Retrieving from cache.")

    return model_name
def inference(image_input,
              user_message,
              temperature,
              top_p,
              max_new_tokens,
              tile_num,
              chatbot,
              state,
              model_name):
    """Run one conversation turn against the cached model and update the chat.

    Args:
        image_input: Image handed to the model (the UI supplies a file path).
            Required; a missing image produces a "System" notice instead.
        user_message: Prompt text for this turn.
        temperature: Sampling temperature; ``0`` turns sampling off.
        top_p: Nucleus-sampling probability mass.
        max_new_tokens: Upper bound on generated tokens.
        tile_num: Forwarded to the model as ``max_tiles``.
        chatbot: List of ``(user, assistant)`` pairs shown in the UI, or
            ``None`` for an empty conversation.
        state: History returned by the model on the previous turn, or ``None``
            to start a new conversation.
        model_name: Key of the model in the caches, or ``None`` when no model
            has been selected yet.

    Returns:
        ``(chatbot, state, "")``. The trailing empty string is the new value of
        the question box. On any validation failure a ``("System", message)``
        entry is appended to ``chatbot`` and ``state`` is returned unchanged.
    """
    # Normalise the chat history first: every early-exit path below appends a
    # notice to it, so it must be a list before any of them can run.
    if chatbot is None:
        chatbot = []

    # Check if model_state is None
    if model_name is None:
        chatbot.append(("System", "Please select a model to start the conversation."))
        return chatbot, state, ""

    # Only the cache lookup happens under the lock; generation runs without it.
    with model_lock:
        if model_name not in model_cache:
            chatbot.append(("System", "Model not loaded. Please wait for the model to load."))
            return chatbot, state, ""
        model = model_cache[model_name]
        tokenizer = tokenizer_cache[model_name]

    # Check for empty or invalid user message. "system" is rejected because
    # that label is used for the notices this function writes into the chat.
    if not user_message or user_message.strip() == '' or user_message.lower() == 'system':
        chatbot.append(("System", "Please enter a valid message to continue the conversation."))
        return chatbot, state, ""

    # An image is required for every turn.
    if image_input is None:
        chatbot.append(("System", "Please provide an image to start the conversation."))
        return chatbot, state, ""

    # Append user message to chatbot; the answer slot is filled in below.
    chatbot.append((user_message, None))

    # Set generation config
    temperature = float(temperature)
    generation_config = dict(
        num_beams=1,
        max_new_tokens=int(max_new_tokens),
        do_sample=(temperature != 0.0),
        temperature=temperature,
        top_p=float(top_p),
    )

    # The 0.8B model is driven through its `ocr` entry point; every other model
    # goes through `chat`, so a name matching neither '2b' nor '0.8b' no longer
    # leaves the response unbound.
    if '0.8b' in model_name.lower():
        generate = model.ocr
    else:
        generate = model.chat

    # `state` may be None here; it is passed through as the history unchanged.
    response_text, new_state = generate(
        tokenizer,
        image_input,
        user_message,
        max_tiles=int(tile_num),
        generation_config=generation_config,
        history=state,
        return_history=True
    )

    # Update the state with new_state
    state = new_state

    # Update chatbot with the model's response
    chatbot[-1] = (user_message, response_text)

    return chatbot, state, ""
def regenerate_response(chatbot,
                        temperature,
                        top_p,
                        max_new_tokens,
                        tile_num,
                        state,
                        image_input,
                        model_name):
    """Re-run the model on the last user message and replace its answer.

    Args:
        chatbot: List of ``(user, assistant)`` pairs shown in the UI, or
            ``None`` for an empty conversation.
        temperature: Sampling temperature; ``0`` turns sampling off.
        top_p: Nucleus-sampling probability mass.
        max_new_tokens: Upper bound on generated tokens.
        tile_num: Forwarded to the model as ``max_tiles``.
        state: History returned by the model on the previous turn, or ``None``.
            Its last entry is dropped before the model is called again.
        image_input: Image handed to the model.
        model_name: Key of the model in the caches, or ``None`` when no model
            has been selected yet.

    Returns:
        ``(chatbot, state)``. On any validation failure a
        ``("System", message)`` entry is appended to ``chatbot`` and ``state``
        is returned unchanged.
    """
    # Normalise the chat history first: every early-exit path below appends a
    # notice to it, so it must be a list before any of them can run.
    if chatbot is None:
        chatbot = []

    # Check if model_state is None
    if model_name is None:
        chatbot.append(("System", "Please select a model to start the conversation."))
        return chatbot, state

    # Only the cache lookup happens under the lock; generation runs without it.
    with model_lock:
        if model_name not in model_cache:
            chatbot.append(("System", "Model not loaded. Please wait for the model to load."))
            return chatbot, state
        model = model_cache[model_name]
        tokenizer = tokenizer_cache[model_name]

    # Check if there is a previous user message
    if len(chatbot) == 0:
        chatbot.append(("System", "Nothing to regenerate. Please start a conversation first."))
        return chatbot, state

    # Get the last user message
    last_user_message, _ = chatbot[-1]

    # Check for empty or invalid last user message. A last entry labelled
    # "System" is a notice written by this app, not something to regenerate.
    if not last_user_message or last_user_message.strip() == '' or last_user_message.lower() == 'system':
        chatbot.append(("System", "Cannot regenerate response for an empty or invalid message."))
        return chatbot, state

    # Drop the last entry from the history so the model answers the same
    # message again; an empty history is passed to the model as None.
    if state is not None and len(state) > 0:
        state = state[:-1]
        if len(state) == 0:
            state = None
    else:
        state = None

    # Set generation config
    temperature = float(temperature)
    generation_config = dict(
        num_beams=1,
        max_new_tokens=int(max_new_tokens),
        do_sample=(temperature != 0.0),
        temperature=temperature,
        top_p=float(top_p),
    )

    # The 0.8B model is driven through its `ocr` entry point; every other model
    # goes through `chat`, so a name matching neither '2b' nor '0.8b' no longer
    # leaves the response unbound.
    if '0.8b' in model_name.lower():
        generate = model.ocr
    else:
        generate = model.chat

    # Regenerate the response
    response_text, new_state = generate(
        tokenizer,
        image_input,
        last_user_message,
        max_tiles=int(tile_num),
        generation_config=generation_config,
        history=state,  # Exclude last assistant's response
        return_history=True
    )

    # Update the state with new_state
    state = new_state

    # Update chatbot with the regenerated response
    chatbot[-1] = (last_user_message, response_text)

    return chatbot, state
def clear_all():
    """Return reset values for the Clear button's four outputs.

    The tuple is, in order: an empty chat history, ``None`` for the
    conversation state, ``None`` for the image input, and an empty string for
    the question box.
    """
    empty_chat = []
    return empty_chat, None, None, ""
# Page header markup rendered through gr.HTML at the top of the interface:
# the app title followed by links to the model collection, the paper and the
# benchmarks space. The `gradient-text` / `plain-text` classes are styled by
# the <style> block emitted when the interface is built.
title_html = """
<h1> <span class="gradient-text" id="text">H2OVL-Mississippi</span><span class="plain-text">: Lightweight Vision Language Models for OCR and Doc AI tasks</span></h1>
<a href="https://huggingface.co/collections/h2oai/h2ovl-mississippi-66e492da45da0a1b7ea7cf39">[😊 Hugging Face]</a>
<a href="https://arxiv.org/abs/2410.13611">[📜 Paper]</a>
<a href="https://huggingface.co/spaces/h2oai/h2ovl-mississippi-benchmarks">[🌟 Benchmarks]</a>
"""
# Build the Gradio interface
with gr.Blocks() as demo:
    gr.HTML(title_html)
    # Inline CSS for the classes referenced by title_html.
    gr.HTML("""
<style>
.gradient-text {
font-size: 36px !important;
font-weight: bold !important;
}
.plain-text {
font-size: 32px !important;
}
h1 {
margin-bottom: 20px !important;
}
</style>
""")
    # Conversation history passed to and returned by the inference functions.
    state= gr.State()
    # Name of the loaded model; written by load_model_and_set_image_function.
    model_state = gr.State()
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(model_paths.keys()),
            label="Select Model",
            value="H2OVL-Mississippi-2B"
        )
        task_type_dropdown = gr.Dropdown(
            choices=["OCR", "Document extractor", "Chat"],
            label="Select Task Type",
            value="Document extractor"
        )
    with gr.Row(equal_height=True):
        # First column with image input
        with gr.Column(scale=1):
            image_input = gr.Image(type="filepath", label="Upload an Image")
        # Second column with chatbot and user input
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Conversation")
            # Dropdown doubles as a free-text box: preset prompts are offered
            # and custom values are allowed.
            user_input = gr.Dropdown(label="What is your question",
                                     choices = example_prompts,
                                     value=None,
                                     allow_custom_value=True,
                                     interactive=True)

    def reset_chatbot_state():
        # reset chatbot and state
        return [], None

    # When the model selection changes, load the new model
    model_dropdown.change(
        fn=load_model_and_set_image_function,
        inputs=[model_dropdown],
        outputs=[model_state]
    )
    # A model change also starts a fresh conversation.
    model_dropdown.change(
        fn=reset_chatbot_state,
        inputs=None,
        outputs=[chatbot, state]
    )
    # Reset chatbot and state when image input changes
    image_input.change(
        fn=reset_chatbot_state,
        inputs=None,
        outputs=[chatbot, state]
    )
    # Load the default model when the app starts
    demo.load(
        fn=load_model_and_set_image_function,
        inputs=[model_dropdown],
        outputs=[model_state]
    )
    with gr.Accordion('Parameters', open=False):
        with gr.Row():
            temperature_input = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=0.2,
                interactive=True,
                label="Temperature")
            top_p_input = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=0.9,
                interactive=True,
                label="Top P")
            max_new_tokens_input = gr.Slider(
                minimum=64,
                maximum=4096,
                step=64,
                value=1024,
                interactive=True,
                label="Max New Tokens (default: 1024)")
            # Forwarded to the model as `max_tiles` by the inference functions.
            tile_num = gr.Slider(
                minimum=2,
                maximum=12,
                step=1,
                value=6,
                interactive=True,
                label="Tile Number (default: 6)"
            )
    # Registered here, after the sliders exist, because this handler writes to
    # max_new_tokens_input.
    model_dropdown.change(
        fn=update_task_type_on_model_change,
        inputs=[model_dropdown],
        outputs=[task_type_dropdown, max_new_tokens_input]
    )
    # Changing the task type updates the token budget and the question box.
    task_type_dropdown.change(
        fn=handle_task_type_and_prompt,
        inputs=[task_type_dropdown, model_dropdown],
        outputs=[max_new_tokens_input, user_input]
    )
    with gr.Row():
        submit_button = gr.Button("Submit")
        regenerate_button = gr.Button("Regenerate")
        clear_button = gr.Button("Clear")
    # When the submit button is clicked, call the inference function
    submit_button.click(
        fn=inference,
        inputs=[
            image_input,
            user_input,
            temperature_input,
            top_p_input,
            max_new_tokens_input,
            tile_num,
            chatbot,
            state,
            model_state
        ],
        outputs=[chatbot, state, user_input]
    )
    # When the regenerate button is clicked, re-run the last inference
    regenerate_button.click(
        fn=regenerate_response,
        inputs=[
            chatbot,
            temperature_input,
            top_p_input,
            max_new_tokens_input,
            tile_num,
            state,
            image_input,
            model_state
        ],
        outputs=[chatbot, state]
    )
    # Clear resets the chat, the state, the image and the question box.
    clear_button.click(
        fn=clear_all,
        inputs=None,
        outputs=[chatbot, state, image_input, user_input]
    )

    def example_clicked(image_value, user_input_value):
        # Pass the example's image and prompt through, with an empty chat and state.
        chatbot_value, state_value = [], None
        return image_value, user_input_value, chatbot_value, state_value # Reset chatbot and state

    # NOTE(review): whether `fn` is invoked when an example is clicked depends
    # on the Examples settings of the installed Gradio version (e.g. caching /
    # run-on-click) - confirm against the Gradio documentation.
    gr.Examples(
        examples=[
            ["assets/handwritten-note-example.jpg", "Read the text and provide word by word ocr for the document. <doc>"],
            ["assets/rental_application.png", "Read the text and provide word by word ocr for the document. <doc>"],
            ["assets/receipt.jpg", "Read the text and provide word by word ocr for the document. <doc>"],
            ["assets/driver_license.png", "Extract the text from the image and fill the following json {'license_number':'',\n'full_name':'',\n'date_of_birth':'',\n'address':'',\n'issue_date':'',\n'expiration_date':'',\n}"],
            ["assets/invoice.png", "Please extract the following fields, and return the result in JSON format: supplier_name, supplier_address, customer_name, customer_address, invoice_number, invoice_total_amount, invoice_tax_amount"],
            ["assets/CBA-1H23-Results-Presentation_wheel.png", "What is the efficiency of H2O.AI in document processing?"],
        ],
        inputs = [image_input, user_input],
        outputs = [image_input, user_input, chatbot, state],
        fn=example_clicked,
        label = "examples",
    )
# Enable request queuing, then start the server.
demo.queue()
demo.launch(max_threads=10)