# app.py
import os
from pathlib import Path
import torch
from threading import Event, Thread
from typing import List, Tuple
# Importing necessary packages
from transformers import AutoConfig, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from langchain_community.tools import DuckDuckGoSearchRun
from optimum.intel.openvino import OVModelForCausalLM
import openvino as ov
import openvino.properties as props
import openvino.properties.hint as hints
import openvino.properties.streams as streams
from gradio_helper import make_demo # UI logic import
from llm_config import SUPPORTED_LLM_MODELS
# Model configuration setup
max_new_tokens = 256
model_language_value = "English"
model_id_value = 'qwen2.5-0.5b-instruct'
prepare_int4_model_value = True
enable_awq_value = False
device_value = 'CPU'
model_to_run_value = 'INT4'
pt_model_id = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]["model_id"]
pt_model_name = model_id_value.split("-")[0]
int4_model_dir = Path(model_id_value) / "INT4_compressed_weights"
int4_weights = int4_model_dir / "openvino_model.bin"
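# Note: the INT4 weights are assumed to have been exported ahead of time into
# <model_id>/INT4_compressed_weights (e.g. with `optimum-cli export openvino --weight-format int4`);
# this script does not perform the conversion itself.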
model_configuration = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]
model_name = model_configuration["model_id"]
start_message = model_configuration["start_message"]
history_template = model_configuration.get("history_template")
has_chat_template = model_configuration.get("has_chat_template", history_template is None)
current_message_template = model_configuration.get("current_message_template")
stop_tokens = model_configuration.get("stop_tokens")
tokenizer_kwargs = model_configuration.get("tokenizer_kwargs", {})
# Model loading
core = ov.Core()
ov_config = {
    hints.performance_mode(): hints.PerformanceMode.LATENCY,
    streams.num(): "1",
    props.cache_dir(): "",
}
tok = AutoTokenizer.from_pretrained(int4_model_dir, trust_remote_code=True)
ov_model = OVModelForCausalLM.from_pretrained(
    int4_model_dir,
    device=device_value,
    ov_config=ov_config,
    config=AutoConfig.from_pretrained(int4_model_dir, trust_remote_code=True),
    trust_remote_code=True,
)
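# The model is compiled for the target device at load time. "GPU" or "AUTO" could be used
# instead of "CPU" above if a supported accelerator is available (not assumed here).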
# Define stopping criteria for specific token sequences
class StopOnTokens(StoppingCriteria):
    def __init__(self, token_ids):
        self.token_ids = token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return any(input_ids[0][-1] == stop_id for stop_id in self.token_ids)
if stop_tokens is not None:
    if isinstance(stop_tokens[0], str):
        stop_tokens = tok.convert_tokens_to_ids(stop_tokens)
    stop_tokens = [StopOnTokens(stop_tokens)]
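# After this block, stop_tokens is either None or a list containing a single
# StopOnTokens criterion built from the configured stop-token ids.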
# Helper function for partial text update
def default_partial_text_processor(partial_text: str, new_text: str) -> str:
    return partial_text + new_text
text_processor = model_configuration.get("partial_text_processor", default_partial_text_processor)
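# Some models need model-specific handling of streamed chunks (for example stripping
# special markers); the config may supply its own partial_text_processor, otherwise
# plain concatenation is used.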
# Convert conversation history to tokens based on model template
def convert_history_to_token(history: List[Tuple[str, str]]):
    if pt_model_name == "baichuan2":
        system_tokens = tok.encode(start_message)
        history_tokens = []
        for old_query, response in history[:-1]:
            round_tokens = [195] + tok.encode(old_query) + [196] + tok.encode(response)
            history_tokens = round_tokens + history_tokens
        input_tokens = system_tokens + history_tokens + [195] + tok.encode(history[-1][0]) + [196]
        input_token = torch.LongTensor([input_tokens])
    elif history_template is None or has_chat_template:
        messages = [{"role": "system", "content": start_message}]
        for idx, (user_msg, model_msg) in enumerate(history):
            if idx == len(history) - 1 and not model_msg:
                messages.append({"role": "user", "content": user_msg})
                break
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if model_msg:
                messages.append({"role": "assistant", "content": model_msg})
        input_token = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_tensors="pt")
    else:
        text = start_message + "".join(
            [history_template.format(num=round, user=item[0], assistant=item[1]) for round, item in enumerate(history[:-1])]
        )
        text += current_message_template.format(num=len(history) + 1, user=history[-1][0], assistant=history[-1][1])
        input_token = tok(text, return_tensors="pt", **tokenizer_kwargs).input_ids
    return input_token
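# `history` is a list of (user_message, assistant_message) pairs; the last entry holds the
# pending user turn with an empty assistant slot, e.g. [("Hi", "Hello!"), ("What is OpenVINO?", "")].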
# Initialize search tool
search = DuckDuckGoSearchRun()
# Determine if a search is needed based on the query
def should_use_search(query: str) -> bool:
    search_keywords = ["latest", "news", "update", "which", "who", "what", "when", "why", "how", "recent", "current",
                       "announcement", "bulletin", "report", "brief", "insight", "disclosure",
                       "release", "memo", "headline", "ongoing", "fresh", "upcoming", "immediate",
                       "recently", "new", "now", "in-progress", "inquiry", "query", "ask", "investigate",
                       "explore", "seek", "clarify", "confirm", "discover", "learn", "describe", "define",
                       "illustrate", "outline", "interpret", "expound", "detail", "summarize", "elucidate",
                       "break down", "outcome", "effect", "consequence", "finding", "achievement", "conclusion",
                       "product", "performance", "resolution"
                       ]
    return any(keyword in query.lower() for keyword in search_keywords)
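# Simple keyword heuristic, e.g. "what is the latest OpenVINO release?" triggers a search,
# while "tell me a joke" does not.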
# Construct the prompt with optional search context
def construct_model_prompt(user_query: str, search_context: str, history: List[Tuple[str, str]]) -> str:
    instructions = (
        "Based on the information provided below, deliver an accurate, concise, and easily understandable answer. If relevant information is missing, draw on your general knowledge and mention the absence of specific details."
    )
    prompt = f"{instructions}\n\n{search_context if search_context else ''}\n\n{user_query} ?\n\n"
    return prompt
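# `history` is accepted for signature compatibility with bot() but is not used when building the prompt.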
# Fetch search results for a query
def fetch_search_results(query: str) -> str:
    search_results = search.invoke(query)
    print("Search results:", search_results)  # Optional: Debugging output
    return f"Relevant and recent information:\n{search_results}"
# Main chatbot function
def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
    user_query = history[-1][0]
    search_context = fetch_search_results(user_query) if should_use_search(user_query) else ""
    prompt = construct_model_prompt(user_query, search_context, history)
    input_ids = tok(prompt, return_tensors="pt", truncation=True, max_length=2500).input_ids if search_context else convert_history_to_token(history)

    # Limit input length to avoid exceeding token limit
    if input_ids.shape[1] > 2000:
        history = [history[-1]]

    # Configure response streaming
    streamer = TextIteratorStreamer(tok, timeout=4600.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "do_sample": temperature > 0.0,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
        "stopping_criteria": StoppingCriteriaList(stop_tokens) if stop_tokens is not None else None,
    }
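    # Greedy decoding is used when temperature is 0; otherwise sampling with the given
    # top_p / top_k is applied. stopping_criteria stays None for models without configured stop tokens.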
    # Signal completion
    stream_complete = Event()

    def generate_and_signal_complete():
        try:
            ov_model.generate(**generate_kwargs)
        except RuntimeError as e:
            # Check if the error message indicates the request was canceled
            if "Infer Request was canceled" in str(e):
                print("Generation request was canceled.")
            else:
                # If it's a different RuntimeError, re-raise it
                raise e
        finally:
            # Signal completion of the stream
            stream_complete.set()
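    # Generation runs in a worker thread; TextIteratorStreamer hands decoded text chunks
    # to the loop below so the UI can update incrementally.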
    t1 = Thread(target=generate_and_signal_complete)
    t1.start()

    partial_text = ""
    for new_text in streamer:
        partial_text = text_processor(partial_text, new_text)
        history[-1] = (user_query, partial_text)
        yield history
def request_cancel():
    ov_model.request.cancel()
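# Cancelling the active infer request surfaces as the "Infer Request was canceled" RuntimeError
# that generate_and_signal_complete() handles above.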
# Gradio setup and launch
demo = make_demo(run_fn=bot, title="OpenVINO Search & Reasoning Chatbot", language=model_language_value)
if __name__ == "__main__":
    demo.launch(debug=True, share=True, server_name="0.0.0.0", server_port=7860)