# llm-chatbot / app.py
import os
import torch
import gradio as gr
from pathlib import Path
from transformers import AutoConfig, AutoTokenizer, TextIteratorStreamer
from optimum.intel.openvino import OVModelForCausalLM
from typing import List, Tuple
from threading import Event, Thread
from gradio_helper import make_demo  # Helper function for building the Gradio demo
from llm_config import SUPPORTED_LLM_MODELS  # Supported model configurations
from notebook_utils import device_widget  # Device selection utility
import openvino as ov
import openvino.properties as props
import openvino.properties.hint as hints
import openvino.properties.streams as streams
import requests
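
# The chat handler (`bot`, below) expects a loaded model and tokenizer at module level.
# These placeholders are an assumption, not part of the original file; they are
# populated by load_model_on_select once a model is chosen in the UI.
ov_model = None
tok = None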


# Convert the selected model to INT4 OpenVINO weights (same conversion flow as in the notebook)
def convert_to_int4(model_id, model_configuration, enable_awq=False):
    compression_configs = {
        "qwen2.5-0.5b-instruct": {"sym": True, "group_size": 128, "ratio": 1.0},
        "default": {"sym": False, "group_size": 128, "ratio": 0.8},
    }
    model_compression_params = compression_configs.get(model_id, compression_configs["default"])

    # Reuse previously compressed weights if they already exist
    int4_model_dir = Path(model_id) / "INT4_compressed_weights"
    if (int4_model_dir / "openvino_model.xml").exists():
        return int4_model_dir

    # Assemble the optimum-cli export command
    remote_code = model_configuration.get("remote_code", False)
    export_command_base = (
        f"optimum-cli export openvino --model {model_configuration['model_id']} "
        "--task text-generation-with-past --weight-format int4"
    )
    int4_compression_args = (
        f" --group-size {model_compression_params['group_size']}"
        f" --ratio {model_compression_params['ratio']}"
    )
    if model_compression_params["sym"]:
        int4_compression_args += " --sym"
    if enable_awq:
        int4_compression_args += " --awq --dataset wikitext2 --num-samples 128"
    export_command_base += int4_compression_args
    if remote_code:
        export_command_base += " --trust-remote-code"
    export_command = export_command_base + f" {str(int4_model_dir)}"

    # Run the export as a shell command
    os.system(export_command)
    return int4_model_dir
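

# For reference, with the "default" compression settings above the assembled command looks
# roughly like the following (illustrative only; the actual --model value comes from the
# entry in SUPPORTED_LLM_MODELS):
#   optimum-cli export openvino --model <hf_model_id> \
#       --task text-generation-with-past --weight-format int4 \
#       --group-size 128 --ratio 0.8 <model_id>/INT4_compressed_weights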


# Load the OpenVINO model and tokenizer from a converted model directory
def load_model(model_dir, device):
    ov_config = {
        hints.performance_mode(): hints.PerformanceMode.LATENCY,
        streams.num(): "1",
        props.cache_dir(): "",
    }
    core = ov.Core()
    tok = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    ov_model = OVModelForCausalLM.from_pretrained(
        model_dir,
        device=device,
        ov_config=ov_config,
        config=AutoConfig.from_pretrained(model_dir, trust_remote_code=True),
        trust_remote_code=True,
    )
    return ov_model, tok
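

# NOTE: `bot` below relies on two helpers, convert_history_to_token and text_processor,
# that are not defined in this file (they come from the accompanying notebook).
# The definitions here are minimal sketches under that assumption: history is a list of
# [user, assistant] pairs and the tokenizer exposes a chat template. Replace them with
# the notebook's own helpers if they are available.
def convert_history_to_token(history):
    # Assumed helper: build input_ids from the chat history via the tokenizer's chat template
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    return tok.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")


def text_processor(partial_text, new_text):
    # Assumed helper: append the newly streamed chunk to the accumulated response text
    return partial_text + new_text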


# Chat handler: stream a response for the latest turn in the conversation
def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
    input_ids = convert_history_to_token(history)
    if input_ids.shape[1] > 2000:
        history = [history[-1]]  # Keep only the latest turn to limit the input size
        input_ids = convert_history_to_token(history)
    streamer = TextIteratorStreamer(tok, timeout=3600.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=256,
        temperature=temperature,
        do_sample=temperature > 0.0,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
    )

    stream_complete = Event()

    # Generate the response in a separate thread so tokens can be streamed to the UI
    def generate_and_signal_complete():
        ov_model.generate(**generate_kwargs)
        stream_complete.set()

    t1 = Thread(target=generate_and_signal_complete)
    t1.start()

    # Accumulate partial text and yield the updated history as new tokens arrive
    partial_text = ""
    for new_text in streamer:
        partial_text = text_processor(partial_text, new_text)
        history[-1][1] = partial_text
        yield history


# Build the Gradio interface for user interaction
def create_gradio_interface():
    model_languages = list(SUPPORTED_LLM_MODELS.keys())  # Available model languages
    default_model_ids = list(SUPPORTED_LLM_MODELS[model_languages[0]].keys())

    # All components must be created inside the Blocks context
    with gr.Blocks() as demo:
        # Dropdowns for selecting the model language and the model ID within that language
        with gr.Row():
            model_language = gr.Dropdown(choices=model_languages, value=model_languages[0], label="Model Language")
            model_id = gr.Dropdown(choices=default_model_ids, value=default_model_ids[0], label="Model ID")
            enable_awq = gr.Checkbox(value=False, label="Enable AWQ for Compression")

        # When the language changes, repopulate the model ID dropdown
        def update_model_ids(model_language):
            model_ids = list(SUPPORTED_LLM_MODELS[model_language].keys())
            return gr.update(choices=model_ids, value=model_ids[0])

        model_language.change(update_model_ids, inputs=model_language, outputs=model_id)

        # Convert (if needed) and load the selected model, storing it in the module-level globals
        def load_model_on_select(model_language, model_id, enable_awq):
            global ov_model, tok
            model_configuration = SUPPORTED_LLM_MODELS[model_language][model_id]
            int4_model_dir = convert_to_int4(model_id, model_configuration, enable_awq)
            device = device_widget("CPU")  # or any device you want to use
            ov_model, tok = load_model(int4_model_dir, device)
            return f"Loaded {model_id}"

        with gr.Row():
            load_button = gr.Button("Load Model")
            model_status = gr.Textbox(label="Model Status")
        load_button.click(load_model_on_select, inputs=[model_language, model_id, enable_awq], outputs=[model_status])

        # Chatbot output
        chatbot = gr.Chatbot()

        # Generation parameters
        with gr.Row():
            temperature = gr.Slider(minimum=0, maximum=1, step=0.1, label="Temperature", value=0.7)
            top_p = gr.Slider(minimum=0, maximum=1, step=0.1, label="Top-p", value=0.9)
            top_k = gr.Slider(minimum=0, maximum=50, step=1, label="Top-k", value=50)
            repetition_penalty = gr.Slider(minimum=0, maximum=2, step=0.1, label="Repetition Penalty", value=1.0)
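
        # NOTE: the original file defines `bot` but never connects it to the chat UI.
        # The wiring below is a minimal sketch, not part of the original: the message
        # textbox, the `user` helper, and the conversation_id state are assumptions.
        msg = gr.Textbox(label="Your message")
        conversation_id = gr.State("")

        def user(message, history):
            # Append the user turn with an empty assistant slot, then clear the textbox
            return "", history + [[message, None]]

        msg.submit(
            user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
        ).then(
            bot,
            inputs=[chatbot, temperature, top_p, top_k, repetition_penalty, conversation_id],
            outputs=chatbot,
        )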

    # Queue requests so concurrent generation calls are handled in order
    demo.queue()
    return demo


# Run the Gradio app
if __name__ == "__main__":
    app = create_gradio_interface()
    app.launch(debug=True, share=True)  # share=True for public access