# llm-chatbot / app.py
# File-view header: author lightmate, commit 4acb2ad "Update app.py" (verified), 4.68 kB.
import os
import shutil
from functools import lru_cache
from pathlib import Path
from threading import Event, Thread

import gradio as gr
import openvino as ov
import openvino.properties as props
import openvino.properties.hint as hints
import openvino.properties.streams as streams
import requests
import torch
from optimum.intel.openvino import OVModelForCausalLM
from transformers import AutoConfig, AutoTokenizer

from llm_config import SUPPORTED_LLM_MODELS
from notebook_utils import device_widget
# Language keys of SUPPORTED_LLM_MODELS, in that mapping's iteration order;
# they become the choices of the "Model Language" dropdown.
model_languages = [*SUPPORTED_LLM_MODELS]
def update_model_id(model_language_value):
    """Refresh the model dropdown after the language selection changes.

    Args:
        model_language_value: A key of ``SUPPORTED_LLM_MODELS``.

    Returns:
        A single ``gr.update`` that replaces the dropdown's choices with the
        models available for the language and selects the first one (or
        ``None`` when the language has no models). It is a single value
        because the ``model_language.change`` event wires exactly one output
        component (``outputs=[model_id]``). Returning a tuple there makes
        Gradio reject the handler's return value.
    """
    model_ids = list(SUPPORTED_LLM_MODELS[model_language_value])
    first_model = model_ids[0] if model_ids else None
    return gr.update(choices=model_ids, value=first_model)
# Function to download the model if not already present
def download_model_if_needed(model_language_value, model_id_value):
model_configuration, int4_model_dir, pt_model_name = get_model_path(model_language_value, model_id_value)
int4_weights = int4_model_dir / "openvino_model.bin"
if not int4_weights.exists():
print(f"Downloading model {model_id_value}...")
# Add your download logic here (e.g., from a URL)
# Example:
# r = requests.get(model_configuration["model_url"])
# with open(int4_weights, "wb") as f:
# f.write(r.content)
return int4_model_dir
# Load the model
def load_model(model_language_value, model_id_value):
int4_model_dir = download_model_if_needed(model_language_value, model_id_value)
ov_config = {hints.performance_mode(): hints.PerformanceMode.LATENCY, streams.num(): "1", props.cache_dir(): ""}
core = ov.Core()
model_dir = int4_model_dir
model_configuration = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]
tok = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
ov_model = OVModelForCausalLM.from_pretrained(
model_dir,
device=device.value,
ov_config=ov_config,
config=AutoConfig.from_pretrained(model_dir, trust_remote_code=True),
trust_remote_code=True
)
return tok, ov_model, model_configuration
# Gradio handler that streams the model's reply into the chat history.
def _history_to_input_ids(tok, history):
    """Tokenize the ``[user, assistant]`` history pairs into a prompt tensor."""
    if getattr(tok, "chat_template", None):
        messages = []
        for turn in history:
            messages.append({"role": "user", "content": turn[0]})
            # The pending last turn has no reply yet, so it contributes only the user part.
            if turn[1]:
                messages.append({"role": "assistant", "content": turn[1]})
        # Render to text first, then tokenize. The template already inserts any
        # special tokens it needs, hence add_special_tokens=False.
        prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        return tok(prompt, return_tensors="pt", add_special_tokens=False).input_ids
    # Fallback for tokenizers without a chat template: join the user turns
    # (the original prompt construction).
    return tok(" ".join(turn[0] for turn in history), return_tensors="pt").input_ids


def generate_response(history, temperature, top_p, top_k, repetition_penalty, model_language_value, model_id_value):
    """Stream a model reply for the last turn of ``history``.

    Args:
        history: List of ``[user_message, assistant_reply]`` pairs (Gradio
            Chatbot tuple format). The last pair's reply is filled in place.
        temperature: Sampling temperature. ``0`` selects greedy decoding.
        top_p, top_k: Nucleus / top-k sampling limits, used only when sampling.
        repetition_penalty: Penalty applied to already generated tokens.
        model_language_value, model_id_value: Model selection passed to ``load_model``.

    Yields:
        ``history`` after each decoded chunk, for progressive display.

    Raises:
        Any exception raised by ``generate()`` on the worker thread. It is
        re-raised here instead of leaving the UI waiting on the streamer.
    """
    # Local import keeps this block self-contained; the module already depends on transformers.
    from transformers import TextIteratorStreamer

    # Nothing to answer: no turns yet, or the last turn already has a reply.
    if not history or history[-1][1]:
        yield history
        return

    tok, ov_model, _model_configuration = load_model(model_language_value, model_id_value)
    input_ids = _history_to_input_ids(tok, history)

    # Yields decoded text as generate() produces tokens on the worker thread.
    # The timeout keeps this loop from hanging forever if generation stalls.
    streamer = TextIteratorStreamer(tok, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    do_sample = temperature > 0.0
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=256,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
    )
    if do_sample:
        # These knobs only take effect when sampling. Sliders deliver floats,
        # but generate() expects an int top_k.
        generate_kwargs.update(temperature=temperature, top_p=top_p, top_k=int(top_k))

    errors = []

    def generate_in_background():
        try:
            ov_model.generate(**generate_kwargs)
        except Exception as exc:  # re-raised on the caller's thread below
            errors.append(exc)
            streamer.end()  # emit the stop signal so the consumer loop exits

    worker = Thread(target=generate_in_background, daemon=True)
    worker.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        history[-1][1] = partial_text
        yield history
    worker.join()
    if errors:
        raise errors[0]
# Gradio UI: model selection, sampling controls and a chat panel.
with gr.Blocks(title="OpenVINO Chatbot") as iface:
    gr.Markdown("# OpenVINO Chatbot")
    model_language = gr.Dropdown(
        choices=model_languages,
        value=model_languages[0],
        label="Model Language",
    )
    # Pre-populate models for the default language. The dropdown used to
    # start empty with value=None, so the first request failed on lookup.
    initial_model_ids = list(SUPPORTED_LLM_MODELS[model_languages[0]])
    model_id = gr.Dropdown(
        choices=initial_model_ids,
        value=initial_model_ids[0] if initial_model_ids else None,
        label="Model",
    )
    model_language.change(update_model_id, inputs=model_language, outputs=[model_id])
    prepare_int4_model = gr.Checkbox(value=True, label="Prepare INT4 Model")
    enable_awq = gr.Checkbox(value=False, label="Enable AWQ", visible=False)
    # Device selector from notebook_utils. It is not a Gradio component, but
    # load_model() reads ``device.value`` from this module-level name.
    device = device_widget("CPU", exclude=["NPU"])
    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Temperature")
    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Top P")
    top_k = gr.Slider(minimum=0, maximum=50, value=50, step=1, label="Top K")
    repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, label="Repetition Penalty")

    # Chat panel: history is a list of [user, assistant] pairs, as
    # generate_response expects. The old layout had no message input at all.
    chatbot = gr.Chatbot(label="Conversation History")
    message = gr.Textbox(label="Message", placeholder="Type a message and press Enter")
    clear = gr.Button("Clear")

    def add_user_message(user_message, chat_history):
        """Append a ``[user, pending-reply]`` turn and clear the input box."""
        if not user_message or not user_message.strip():
            return user_message, chat_history
        return "", chat_history + [[user_message, None]]

    message.submit(
        add_user_message,
        inputs=[message, chatbot],
        outputs=[message, chatbot],
        queue=False,
    ).then(
        generate_response,
        inputs=[chatbot, temperature, top_p, top_k, repetition_penalty, model_language, model_id],
        outputs=chatbot,
    )
    clear.click(lambda: [], inputs=None, outputs=chatbot, queue=False)

# Launch once, outside the Blocks context and only when run as a script. The
# old code launched a nested gr.Interface inside the unfinished `with` block at
# import time and then tried to launch again.
if __name__ == "__main__":
    # Streaming (generator) handlers need the queue. It is enabled explicitly
    # for older Gradio releases where it is not on by default.
    iface.queue().launch(debug=True, share=True, server_name="0.0.0.0", server_port=7860)