# app.py
import os
from pathlib import Path
import torch
from threading import Event, Thread
from typing import List, Tuple
# Transformers and OpenVINO imports
from transformers import AutoConfig, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from optimum.intel.openvino import OVModelForCausalLM
import openvino as ov
import openvino.properties as props
import openvino.properties.hint as hints
import openvino.properties.streams as streams
from gradio_helper import make_demo # UI logic import
from llm_config import SUPPORTED_LLM_MODELS
# Model configuration setup
model_language_value = "English"
model_id_value = "qwen2.5-0.5b-instruct"
model_configuration = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]
model_name = model_configuration["model_id"]
start_message = model_configuration["start_message"]
history_template = model_configuration.get("history_template")
has_chat_template = model_configuration.get("has_chat_template", history_template is None)
current_message_template = model_configuration.get("current_message_template")
stop_tokens = model_configuration.get("stop_tokens")
tokenizer_kwargs = model_configuration.get("tokenizer_kwargs", {})
# Generation and deployment settings
max_new_tokens = 256
prepare_int4_model_value = True
enable_awq_value = False
device_value = "CPU"
model_to_run_value = "INT4"
pt_model_id = model_configuration["model_id"]
pt_model_name = model_id_value.split("-")[0]
# Paths to the INT4-compressed OpenVINO weights
int4_model_dir = Path(model_id_value) / "INT4_compressed_weights"
int4_weights = int4_model_dir / "openvino_model.bin"
# Model loading
core = ov.Core()
# OpenVINO configuration: favor latency over throughput, run a single
# inference stream, and disable the compiled-model cache (empty cache_dir)
ov_config = {
    hints.performance_mode(): hints.PerformanceMode.LATENCY,
    streams.num(): "1",
    props.cache_dir(): "",
}
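# Load the tokenizer and the INT4-compressed model through Optimum Intel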
tok = AutoTokenizer.from_pretrained(int4_model_dir, trust_remote_code=True)
ov_model = OVModelForCausalLM.from_pretrained(
int4_model_dir,
device=device_value,
ov_config=ov_config,
config=AutoConfig.from_pretrained(int4_model_dir, trust_remote_code=True),
trust_remote_code=True,
)
# Stopping criteria for token generation
class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the last produced token is one of `token_ids`."""

    def __init__(self, token_ids):
        super().__init__()
        self.token_ids = token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return any(input_ids[0][-1] == stop_id for stop_id in self.token_ids)
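# stop_tokens may be given as token strings in the model config; convert them to
# ids and wrap them in StopOnTokens so generation halts at the model's
# end-of-turn markers (standard transformers StoppingCriteria pattern)
if stop_tokens is not None:
    if isinstance(stop_tokens[0], str):
        stop_tokens = tok.convert_tokens_to_ids(stop_tokens)
    stop_tokens = [StopOnTokens(stop_tokens)]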
# Functions for chatbot logic
def convert_history_to_token(history: List[Tuple[str, str]]):
"""
function for conversion history stored as list pairs of user and assistant messages to tokens according to model expected conversation template
Params:
history: dialogue history
Returns:
history in token format
"""
if pt_model_name == "baichuan2":
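        # Baichuan2 has no chat template; its role markers are raw special
        # token ids (195 = user, 196 = assistant)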
system_tokens = tok.encode(start_message)
history_tokens = []
for old_query, response in history[:-1]:
round_tokens = []
round_tokens.append(195)
round_tokens.extend(tok.encode(old_query))
round_tokens.append(196)
round_tokens.extend(tok.encode(response))
history_tokens = round_tokens + history_tokens
input_tokens = system_tokens + history_tokens
input_tokens.append(195)
input_tokens.extend(tok.encode(history[-1][0]))
input_tokens.append(196)
input_token = torch.LongTensor([input_tokens])
elif history_template is None or has_chat_template:
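        # Models that ship a chat template: build role/content messages and let
        # the tokenizer render them with apply_chat_template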
messages = [{"role": "system", "content": start_message}]
for idx, (user_msg, model_msg) in enumerate(history):
if idx == len(history) - 1 and not model_msg:
messages.append({"role": "user", "content": user_msg})
break
if user_msg:
messages.append({"role": "user", "content": user_msg})
if model_msg:
messages.append({"role": "assistant", "content": model_msg})
input_token = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_tensors="pt")
else:
        text = start_message + "".join(
            history_template.format(num=round, user=item[0], assistant=item[1]) for round, item in enumerate(history[:-1])
        )
        text += current_message_template.format(
            num=len(history) + 1,
            user=history[-1][0],
            assistant=history[-1][1],
        )
input_token = tok(text, return_tensors="pt", **tokenizer_kwargs).input_ids
return input_token
def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
    """Generation callback: stream the model's reply to the latest user message."""
input_ids = convert_history_to_token(history)
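    # Keep the prompt within the model's context budget: if it grows past
    # 2000 tokens, fall back to only the latest exchange and re-tokenize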
if input_ids.shape[1] > 2000:
history = [history[-1]]
input_ids = convert_history_to_token(history)
streamer = TextIteratorStreamer(tok, timeout=3600.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=temperature > 0.0,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
    )
    if stop_tokens is not None:
        generate_kwargs["stopping_criteria"] = StoppingCriteriaList(stop_tokens)
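    # Run generation on a worker thread so this function can consume the streamer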
stream_complete = Event()
def generate_and_signal_complete():
ov_model.generate(**generate_kwargs)
stream_complete.set()
Thread(target=generate_and_signal_complete).start()
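    # Accumulate streamed chunks into the assistant slot of the last history turn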
partial_text = ""
for new_text in streamer:
partial_text += new_text
history[-1][1] = partial_text
yield history
def request_cancel():
    """Cancel the in-flight OpenVINO inference request to stop generation early."""
    ov_model.request.cancel()
# Gradio setup and launch
demo = make_demo(run_fn=bot, stop_fn=request_cancel, title=f"OpenVINO {model_id_value} Chatbot", language=model_language_value)
if __name__ == "__main__":
demo.launch(debug=True, share=True, server_name="0.0.0.0", server_port=7860)