# app.py
import os
import re
from pathlib import Path
from threading import Event, Thread
from typing import List, Tuple

import torch

# Importing necessary packages
from transformers import AutoConfig, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from langchain_community.tools import DuckDuckGoSearchRun
from optimum.intel.openvino import OVModelForCausalLM
import openvino as ov
import openvino.properties as props
import openvino.properties.hint as hints
import openvino.properties.streams as streams

from gradio_helper import make_demo  # UI logic import
from llm_config import SUPPORTED_LLM_MODELS

# Model configuration setup
max_new_tokens = 256
model_language_value = "English"
model_id_value = 'qwen2.5-0.5b-instruct'
prepare_int4_model_value = True
enable_awq_value = False
device_value = 'CPU'
model_to_run_value = 'INT4'
pt_model_id = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]["model_id"]
pt_model_name = model_id_value.split("-")[0]
int4_model_dir = Path(model_id_value) / "INT4_compressed_weights"
int4_weights = int4_model_dir / "openvino_model.bin"

model_configuration = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]
model_name = model_configuration["model_id"]
start_message = model_configuration["start_message"]
history_template = model_configuration.get("history_template")
has_chat_template = model_configuration.get("has_chat_template", history_template is None)
current_message_template = model_configuration.get("current_message_template")
stop_tokens = model_configuration.get("stop_tokens")
tokenizer_kwargs = model_configuration.get("tokenizer_kwargs", {})

# Inputs longer than this (in tokens) are rebuilt from the latest user turn only.
max_input_tokens = 2000
# Hard cap passed to the tokenizer when encoding a search-augmented prompt.
max_search_prompt_tokens = 2500

# Model loading
core = ov.Core()
ov_config = {
    hints.performance_mode(): hints.PerformanceMode.LATENCY,
    streams.num(): "1",
    props.cache_dir(): "",
}
tok = AutoTokenizer.from_pretrained(int4_model_dir, trust_remote_code=True)
ov_model = OVModelForCausalLM.from_pretrained(
    int4_model_dir,
    device=device_value,
    ov_config=ov_config,
    config=AutoConfig.from_pretrained(int4_model_dir, trust_remote_code=True),
    trust_remote_code=True,
)


# Define stopping criteria for specific token sequences
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that fires when the newest token equals one of ``token_ids``.

    Only the first sequence of the batch (``input_ids[0]``) is inspected.
    """

    def __init__(self, token_ids):
        self.token_ids = token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return any(input_ids[0][-1] == stop_id for stop_id in self.token_ids)


if stop_tokens is not None:
    # Configured stop tokens may be given as token strings; convert them to ids first.
    if stop_tokens and isinstance(stop_tokens[0], str):
        stop_tokens = tok.convert_tokens_to_ids(stop_tokens)
    stop_tokens = [StopOnTokens(stop_tokens)]


# Helper function for partial text update
def default_partial_text_processor(partial_text: str, new_text: str) -> str:
    """Append a newly streamed text chunk to the text accumulated so far."""
    return partial_text + new_text


text_processor = model_configuration.get("partial_text_processor", default_partial_text_processor)


# Convert conversation history to tokens based on model template
def convert_history_to_token(history: List[Tuple[str, str]]):
    """Encode the dialogue history into model input ids.

    Args:
        history: list of ``(user_message, assistant_message)`` pairs; the last pair
            is the current turn, whose assistant message is normally still empty.

    Returns:
        A tensor of input ids for a batch of one sequence.
    """
    if pt_model_name == "baichuan2":
        system_tokens = tok.encode(start_message)
        history_tokens = []
        for old_query, response in history[:-1]:
            # Id 195 precedes each user turn and 196 each assistant turn.
            # Append (not prepend): rounds are iterated oldest-first, so prepending
            # would emit the conversation in reverse chronological order.
            history_tokens += [195] + tok.encode(old_query) + [196] + tok.encode(response)
        input_tokens = system_tokens + history_tokens + [195] + tok.encode(history[-1][0]) + [196]
        return torch.LongTensor([input_tokens])

    if history_template is None or has_chat_template:
        messages = [{"role": "system", "content": start_message}]
        for idx, (user_msg, model_msg) in enumerate(history):
            # The pending turn has no assistant reply yet: add the user message and stop.
            if idx == len(history) - 1 and not model_msg:
                messages.append({"role": "user", "content": user_msg})
                break
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if model_msg:
                messages.append({"role": "assistant", "content": model_msg})
        return tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_tensors="pt")

    # Plain-text templates: past rounds via history_template, current one via current_message_template.
    text = start_message + "".join(
        history_template.format(num=round_idx, user=item[0], assistant=item[1])
        for round_idx, item in enumerate(history[:-1])
    )
    text += current_message_template.format(num=len(history) + 1, user=history[-1][0], assistant=history[-1][1])
    return tok(text, return_tensors="pt", **tokenizer_kwargs).input_ids


# Initialize search tool
search = DuckDuckGoSearchRun()

# Words that trigger a web search.
SEARCH_KEYWORDS = (
    "latest", "news", "update", "which", "who", "what", "when", "why", "how", "recent",
    "current", "announcement", "bulletin", "report", "brief", "insight", "disclosure",
    "release", "memo", "headline", "ongoing", "fresh", "upcoming", "immediate", "recently",
    "new", "now", "in-progress", "inquiry", "query", "ask", "investigate", "explore", "seek",
    "clarify", "confirm", "discover", "learn", "describe", "define", "illustrate", "outline",
    "interpret", "expound", "detail", "summarize", "elucidate", "break down", "outcome",
    "effect", "consequence", "finding", "achievement", "conclusion", "product", "performance",
    "resolution",
)
# Whole-word match (with an optional plural suffix) so that e.g. "how" does not fire
# on "show", "now" on "know", or "ask" on "task".
_SEARCH_KEYWORD_PATTERN = re.compile(
    r"\b(?:" + "|".join(re.escape(keyword) for keyword in SEARCH_KEYWORDS) + r")(?:s|es)?\b"
)


# Determine if a search is needed based on the query
def should_use_search(query: str) -> bool:
    """Return True if *query* contains any search keyword as a whole word (case-insensitive)."""
    return _SEARCH_KEYWORD_PATTERN.search(query.lower()) is not None


# Construct the prompt with optional search context
def construct_model_prompt(user_query: str, search_context: str, history: List[Tuple[str, str]]) -> str:
    """Build a plain-text prompt: instructions, optional search context, then the query.

    ``history`` is accepted for interface compatibility but is not used.
    """
    instructions = "Use the information below if relevant to provide an accurate and concise answer. If no information is available, rely on your general knowledge."
    prompt = f"{instructions}\n\n{search_context if search_context else ''}\n\n{user_query} ?\n\n"
    return prompt


# Fetch search results for a query
def fetch_search_results(query: str) -> str:
    """Run a DuckDuckGo search for *query* and format the results as prompt context.

    Returns an empty string when the search call fails. The caller then answers
    from the chat history alone instead of crashing the chat.
    """
    try:
        search_results = search.invoke(query)
    except Exception as exc:  # external service boundary: log and degrade gracefully
        print(f"Search failed, answering without web context: {exc}")
        return ""
    print("Search results:", search_results)  # Optional: Debugging output
    return f"Relevant and recent information:\n{search_results}"


# Main chatbot function
def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
    """Stream the model's answer to the latest user turn, yielding the updated history.

    Args:
        history: list of ``(user_message, assistant_message)`` pairs; the last entry is
            the pending turn.
        temperature, top_p, top_k, repetition_penalty: sampling parameters passed to generate.
            Sampling is enabled only when ``temperature > 0``.
        conversation_id: accepted for the UI callback signature; unused here.

    Yields:
        The history list with the last entry's answer replaced by the text generated so far.
    """
    user_query = history[-1][0]
    search_context = fetch_search_results(user_query) if should_use_search(user_query) else ""

    if search_context:
        prompt = construct_model_prompt(user_query, search_context, history)
        # NOTE(review): with the tokenizer's default (right-side) truncation, an over-long
        # search context could cut off the user query at the end of the prompt.
        input_ids = tok(prompt, return_tensors="pt", truncation=True, max_length=max_search_prompt_tokens).input_ids
    else:
        input_ids = convert_history_to_token(history)
        # Limit input length: keep only the latest turn and re-tokenize it,
        # so the shortened history is what the model actually receives.
        if input_ids.shape[1] > max_input_tokens:
            history = [history[-1]]
            input_ids = convert_history_to_token(history)

    # Configure response streaming
    streamer = TextIteratorStreamer(tok, timeout=4600.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "do_sample": temperature > 0.0,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
        "stopping_criteria": StoppingCriteriaList(stop_tokens) if stop_tokens is not None else None,
    }

    # Signal completion
    stream_complete = Event()

    def generate_and_signal_complete():
        """Run generation in a worker thread and always mark completion."""
        try:
            ov_model.generate(**generate_kwargs)
        except Exception as e:
            # Presumably generate() only closes the streamer on normal completion; close it
            # here so the consumer loop below does not block until the streamer timeout.
            streamer.end()
            if isinstance(e, RuntimeError) and "Infer Request was canceled" in str(e):
                print("Generation request was canceled.")
            else:
                raise
        finally:
            # Signal completion of the stream
            stream_complete.set()

    t1 = Thread(target=generate_and_signal_complete)
    t1.start()

    partial_text = ""
    for new_text in streamer:
        partial_text = text_processor(partial_text, new_text)
        history[-1] = (user_query, partial_text)
        yield history
def request_cancel():
    """Cancel the OpenVINO inference request currently held by ``ov_model``."""
    ov_model.request.cancel()


# NOTE(review): request_cancel is not passed to make_demo below; confirm whether
# make_demo accepts a stop callback it should be wired to.
# Gradio setup and launch
demo = make_demo(
    run_fn=bot,
    title="OpenVINO Search & Reasoning Chatbot",
    language=model_language_value,
)

if __name__ == "__main__":
    # NOTE(review): share=True with server_name="0.0.0.0" exposes the app beyond
    # localhost — confirm this is intended.
    demo.launch(
        debug=True,
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
    )