import os
from pathlib import Path
from threading import Event, Thread

import gradio as gr
import openvino.properties as props
import openvino.properties.hint as hints
import openvino.properties.streams as streams
from transformers import AutoConfig, AutoTokenizer, TextIteratorStreamer
from optimum.intel.openvino import OVModelForCausalLM

from llm_config import SUPPORTED_LLM_MODELS  # Model configurations keyed by language, then model ID
from notebook_utils import device_widget  # Device selection utility
# Convert the selected model to INT4 weights via optimum-cli, reusing any existing export
def convert_to_int4(model_id, model_configuration, enable_awq=False):
    compression_configs = {
        "qwen2.5-0.5b-instruct": {"sym": True, "group_size": 128, "ratio": 1.0},
        "default": {"sym": False, "group_size": 128, "ratio": 0.8},
    }
    model_compression_params = compression_configs.get(model_id, compression_configs["default"])

    # Skip the export if compressed weights are already on disk
    int4_model_dir = Path(model_id) / "INT4_compressed_weights"
    if (int4_model_dir / "openvino_model.xml").exists():
        return int4_model_dir

    remote_code = model_configuration.get("remote_code", False)
    export_command_base = (
        f"optimum-cli export openvino --model {model_configuration['model_id']} "
        "--task text-generation-with-past --weight-format int4"
    )
    int4_compression_args = f" --group-size {model_compression_params['group_size']} --ratio {model_compression_params['ratio']}"
    if model_compression_params["sym"]:
        int4_compression_args += " --sym"
    if enable_awq:
        # AWQ calibration improves INT4 accuracy at the cost of a slower export
        int4_compression_args += " --awq --dataset wikitext2 --num-samples 128"
    export_command_base += int4_compression_args
    if remote_code:
        export_command_base += " --trust-remote-code"
    export_command = export_command_base + f" {str(int4_model_dir)}"
    os.system(export_command)  # Run the export as a shell command
    return int4_model_dir
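
# Example usage (a sketch, assuming the "English"/"qwen2.5-0.5b-instruct" entries
# exist in SUPPORTED_LLM_MODELS; substitute whichever keys your llm_config defines):
#
#   model_configuration = SUPPORTED_LLM_MODELS["English"]["qwen2.5-0.5b-instruct"]
#   int4_dir = convert_to_int4("qwen2.5-0.5b-instruct", model_configuration, enable_awq=False)
#   print(f"INT4 weights written to {int4_dir}")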
# Model and tokenizer loading
def load_model(model_dir, device):
    # LATENCY hint plus a single stream suits one-request-at-a-time chat inference
    ov_config = {hints.performance_mode(): hints.PerformanceMode.LATENCY, streams.num(): "1", props.cache_dir(): ""}
    tok = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    ov_model = OVModelForCausalLM.from_pretrained(
        model_dir,
        device=device,
        ov_config=ov_config,
        config=AutoConfig.from_pretrained(model_dir, trust_remote_code=True),
        trust_remote_code=True,
    )
    return ov_model, tok
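
# The chat handler below relies on a loaded model/tokenizer plus two helpers that are
# defined elsewhere in the original notebook. The sketches here are assumptions, not
# the notebook's exact code: convert_history_to_token maps the chat history onto the
# tokenizer's chat template, and text_processor simply appends streamed fragments.
ov_model = None
tok = None
stream_complete = Event()


def convert_history_to_token(history):
    # Assumed helper: turn [[user, assistant], ...] pairs into model input IDs
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    return tok.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")


def text_processor(partial_text, new_text):
    # Assumed helper: accumulate streamed text into the reply shown in the UI
    return partial_text + new_text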
# Gradio chatbot callback: stream the model's reply token by token
def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
    input_ids = convert_history_to_token(history)
    if input_ids.shape[1] > 2000:
        # Keep the prompt within the context budget by dropping older turns
        history = [history[-1]]
        input_ids = convert_history_to_token(history)
    streamer = TextIteratorStreamer(tok, timeout=3600.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=256,
        temperature=temperature,
        do_sample=temperature > 0.0,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
    )

    # Run generation on a worker thread so the streamer can be consumed here
    def generate_and_signal_complete():
        ov_model.generate(**generate_kwargs)
        stream_complete.set()

    stream_complete.clear()
    t1 = Thread(target=generate_and_signal_complete)
    t1.start()

    # Accumulate partial text and yield the updated history as it grows
    partial_text = ""
    for new_text in streamer:
        partial_text = text_processor(partial_text, new_text)
        history[-1][1] = partial_text
        yield history
# Define a Gradio interface for user interaction
def create_gradio_interface():
    model_languages = list(SUPPORTED_LLM_MODELS.keys())

    # Once a model language is selected, repopulate the model-ID dropdown
    def update_model_ids(model_language):
        model_ids = list(SUPPORTED_LLM_MODELS[model_language].keys())
        return gr.update(choices=model_ids, value=model_ids[0])

    # Convert the selected model to INT4 (if needed), load it, and report status
    def load_model_on_select(model_language, model_id, enable_awq):
        global ov_model, tok
        model_configuration = SUPPORTED_LLM_MODELS[model_language][model_id]
        int4_model_dir = convert_to_int4(model_id, model_configuration, enable_awq)
        # device_widget returns a selection widget; .value holds the device string (default CPU)
        device = device_widget("CPU").value
        ov_model, tok = load_model(int4_model_dir, device)
        return f"Loaded {model_id} on {device}"

    with gr.Blocks() as demo:
        # Components must be created inside the Blocks context to be rendered
        with gr.Row():
            model_language_selector = gr.Dropdown(choices=model_languages, value=model_languages[0], label="Model Language")
            model_id_selector = gr.Dropdown(choices=list(SUPPORTED_LLM_MODELS[model_languages[0]].keys()), label="Model ID")
            enable_awq = gr.Checkbox(value=False, label="Enable AWQ for Compression")
        model_language_selector.change(update_model_ids, inputs=model_language_selector, outputs=model_id_selector)

        with gr.Row():
            load_button = gr.Button("Load Model")
            model_status = gr.Textbox(label="Model Status")
        load_button.click(
            load_model_on_select,
            inputs=[model_language_selector, model_id_selector, enable_awq],
            outputs=model_status,
        )

        chatbot = gr.Chatbot()
        msg = gr.Textbox(label="Message")
        conversation_id = gr.State("")

        # Sampling parameters passed to bot() on each generation
        with gr.Row():
            temperature = gr.Slider(minimum=0, maximum=1, step=0.1, label="Temperature", value=0.7)
            top_p = gr.Slider(minimum=0, maximum=1, step=0.1, label="Top-p", value=0.9)
            top_k = gr.Slider(minimum=0, maximum=50, step=1, label="Top-k", value=50)
            repetition_penalty = gr.Slider(minimum=0, maximum=2, step=0.1, label="Repetition Penalty", value=1.0)

        # Append the user turn to the history, then stream the bot's reply into it
        msg.submit(
            lambda message, history: ("", history + [[message, ""]]),
            inputs=[msg, chatbot],
            outputs=[msg, chatbot],
        ).then(
            bot,
            inputs=[chatbot, temperature, top_p, top_k, repetition_penalty, conversation_id],
            outputs=chatbot,
        )

    demo.queue()  # Queue requests so concurrent generation calls are handled in order
    return demo
# Run the Gradio app
if __name__ == "__main__":
    app = create_gradio_interface()
    app.launch(debug=True, share=True)  # share=True exposes a public link