Spaces:
Running
Running
File size: 8,669 Bytes
1cd9f06 657585b 210ec4a d8164ce 1cd9f06 844af01 210ec4a 657585b 1cd9f06 d8164ce 1cd9f06 d4f5d88 1cd9f06 a2c455a 1cd9f06 9f09252 1cd9f06 9f09252 1cd9f06 9f09252 1cd9f06 9f09252 1cd9f06 9f09252 1cd9f06 9f09252 1cd9f06 9f09252 10be371 9f09252 10be371 9f09252 10be371 9f09252 10be371 9f09252 10be371 00eb905 9f09252 10be371 9f09252 10be371 9f09252 1cd9f06 10be371 9f09252 10be371 9f09252 1cd9f06 9f09252 10be371 9f09252 1cd9f06 9f09252 9b61493 10be371 1cd9f06 10be371 1cd9f06 f0d2584 1cd9f06 8f9fe18 1cd9f06 9f09252 210ec4a 1cd9f06 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
# app.py
import os
import re
from pathlib import Path
from threading import Event, Thread
from typing import List, Tuple

# Importing necessary packages
import torch
from transformers import AutoConfig, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from langchain_community.tools import DuckDuckGoSearchRun
from optimum.intel.openvino import OVModelForCausalLM
import openvino as ov
import openvino.properties as props
import openvino.properties.hint as hints
import openvino.properties.streams as streams

from gradio_helper import make_demo  # UI logic import
from llm_config import SUPPORTED_LLM_MODELS
# Model configuration setup
# User-selectable settings: language, model, precision and target device.
max_new_tokens = 256
model_language_value = "English"
model_id_value = "qwen2.5-0.5b-instruct"
prepare_int4_model_value = True
enable_awq_value = False
device_value = "CPU"
model_to_run_value = "INT4"

# Registry entry for the selected model; every derived setting below is read from it.
model_configuration = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]
pt_model_id = model_configuration["model_id"]
model_name = pt_model_id
# Family prefix (text before the first "-"), used to select model-specific
# tokenization in convert_history_to_token.
pt_model_name = model_id_value.partition("-")[0]
int4_model_dir = Path(model_id_value) / "INT4_compressed_weights"
int4_weights = int4_model_dir / "openvino_model.bin"

start_message = model_configuration["start_message"]
history_template = model_configuration.get("history_template")
# Without an explicit flag, a chat template is assumed exactly when no
# history template is configured.
has_chat_template = model_configuration.get("has_chat_template", history_template is None)
current_message_template = model_configuration.get("current_message_template")
stop_tokens = model_configuration.get("stop_tokens")
tokenizer_kwargs = model_configuration.get("tokenizer_kwargs", {})
# Model loading
core = ov.Core()
# OpenVINO runtime properties: latency-oriented performance hint with a single
# inference stream. The empty cache_dir presumably disables the compiled-model
# cache; confirm against the OpenVINO docs if caching is wanted.
ov_config = {
    hints.performance_mode(): hints.PerformanceMode.LATENCY,
    streams.num(): "1",
    props.cache_dir(): "",
}

# Tokenizer, model config and model are all loaded from the INT4-compressed export.
tok = AutoTokenizer.from_pretrained(int4_model_dir, trust_remote_code=True)
_int4_model_config = AutoConfig.from_pretrained(int4_model_dir, trust_remote_code=True)
ov_model = OVModelForCausalLM.from_pretrained(
    int4_model_dir,
    device=device_value,
    ov_config=ov_config,
    config=_int4_model_config,
    trust_remote_code=True,
)
# Define stopping criteria for specific token sequences
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that fires once the newest token is a configured stop id.

    Only the first sequence of the batch (``input_ids[0]``) is inspected.
    """

    def __init__(self, token_ids):
        # Iterable of integer token ids that should end generation.
        self.token_ids = token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of the first sequence matches any stop id."""
        for stop_id in self.token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
# Normalize the configured stop tokens into a single StopOnTokens criterion.
# Entries may be token strings (converted through the tokenizer) or integer
# ids. Each entry is converted on its own, so mixed lists work. An empty list
# is left as-is instead of raising IndexError on stop_tokens[0].
if stop_tokens:
    stop_token_ids = [
        tok.convert_tokens_to_ids(token) if isinstance(token, str) else token
        for token in stop_tokens
    ]
    stop_tokens = [StopOnTokens(stop_token_ids)]
# Helper function for partial text update
def default_partial_text_processor(partial_text: str, new_text: str) -> str:
    """Append a freshly streamed chunk to the text accumulated so far.

    Args:
        partial_text: Text already shown for the current reply.
        new_text: Next chunk produced by the streamer.

    Returns:
        The accumulated text followed by the new chunk.
    """
    updated_text = partial_text + new_text
    return updated_text
# A model entry may provide its own "partial_text_processor" for merging
# streamed chunks; otherwise plain concatenation is used.
text_processor = model_configuration.get(
    "partial_text_processor",
    default_partial_text_processor,
)
# Convert conversation history to tokens based on model template
def convert_history_to_token(history: List[Tuple[str, str]]):
    """Encode the conversation as model input ids.

    The encoding strategy depends on the model configuration:

    * ``pt_model_name == "baichuan2"``: a hand-built id sequence. It is the
      system prompt, then each past round as ``195 + query + 196 + response``
      in chronological order, then ``195 + current query + 196``.
    * No ``history_template``, or ``has_chat_template``: role/content messages
      rendered with ``tok.apply_chat_template`` plus a generation prompt.
    * Otherwise: text built from ``history_template`` (past rounds) and
      ``current_message_template`` (last round), tokenized with
      ``tokenizer_kwargs``.

    Args:
        history: ``(user, assistant)`` pairs; the last pair holds the current
            user message, whose assistant part may be empty.

    Returns:
        A PyTorch tensor of input ids with a leading batch dimension of 1.
    """
    if pt_model_name == "baichuan2":
        user_token_id, assistant_token_id = 195, 196
        input_tokens = list(tok.encode(start_message))
        # Append past rounds so they stay in chronological order. Prepending
        # them while iterating forward would reverse the conversation.
        for old_query, response in history[:-1]:
            input_tokens += [user_token_id] + tok.encode(old_query)
            input_tokens += [assistant_token_id] + tok.encode(response)
        input_tokens += [user_token_id] + tok.encode(history[-1][0]) + [assistant_token_id]
        return torch.LongTensor([input_tokens])

    if history_template is None or has_chat_template:
        messages = [{"role": "system", "content": start_message}]
        last_idx = len(history) - 1
        for idx, (user_msg, model_msg) in enumerate(history):
            # The pending turn (no reply yet) ends the message list.
            if idx == last_idx and not model_msg:
                messages.append({"role": "user", "content": user_msg})
                break
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if model_msg:
                messages.append({"role": "assistant", "content": model_msg})
        return tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_tensors="pt")

    # NOTE(review): past rounds are numbered from 0 while the current round
    # uses len(history) + 1. Confirm this matches the templates in llm_config.
    text = start_message + "".join(
        history_template.format(num=idx, user=user_msg, assistant=model_msg)
        for idx, (user_msg, model_msg) in enumerate(history[:-1])
    )
    text += current_message_template.format(num=len(history) + 1, user=history[-1][0], assistant=history[-1][1])
    return tok(text, return_tensors="pt", **tokenizer_kwargs).input_ids
# Initialize search tool
# Single DuckDuckGo search tool created at import time; fetch_search_results()
# calls its invoke() method.
search: DuckDuckGoSearchRun = DuckDuckGoSearchRun()
# Determine if a search is needed based on the query
# Keywords suggesting the user wants fresh or factual information.
_SEARCH_KEYWORDS = (
    "latest", "news", "update", "which", "who", "what", "when", "why", "how", "recent", "current",
    "announcement", "bulletin", "report", "brief", "insight", "disclosure",
    "release", "memo", "headline", "ongoing", "fresh", "upcoming", "immediate",
    "recently", "new", "now", "in-progress", "inquiry", "query", "ask", "investigate",
    "explore", "seek", "clarify", "confirm", "discover", "learn", "describe", "define",
    "illustrate", "outline", "interpret", "expound", "detail", "summarize", "elucidate",
    "break down", "outcome", "effect", "consequence", "finding", "achievement", "conclusion",
    "product", "performance", "resolution",
)
# A keyword only counts at the start of a word. This keeps "know" from matching
# "now", "task" from matching "ask" and "show" from matching "how", while
# inflections such as "updates" or "asked" still match.
_SEARCH_KEYWORD_PATTERN = re.compile(
    r"\b(?:" + "|".join(map(re.escape, _SEARCH_KEYWORDS)) + ")"
)


def should_use_search(query: str) -> bool:
    """Return True when *query* contains a search-trigger keyword.

    Matching is case-insensitive and anchored at word starts.

    Args:
        query: The user's message.

    Returns:
        True if a web search should be run before answering.
    """
    return _SEARCH_KEYWORD_PATTERN.search(query.lower()) is not None
# Construct the prompt with optional search context
def construct_model_prompt(user_query: str, search_context: str, history: List[Tuple[str, str]]) -> str:
    """Build the plain-text prompt used when search context is available.

    The prompt has three sections separated by blank lines: fixed answering
    instructions, the search context (empty when falsy), and the user query
    followed by " ?". The prompt ends with a blank line.

    Args:
        user_query: The user's current message.
        search_context: Formatted search results; may be empty.
        history: Accepted for interface compatibility; not used.

    Returns:
        The assembled prompt string.
    """
    guidance = (
        "Based on the information provided below, deliver an accurate, concise, "
        "and easily understandable answer. If relevant information is missing, "
        "draw on your general knowledge and mention the absence of specific details."
    )
    sections = [
        guidance,
        f"{search_context}" if search_context else "",
        f"{user_query} ?",
    ]
    return "\n\n".join(sections) + "\n\n"
# Fetch search results for a query
def fetch_search_results(query: str) -> str:
    """Run a web search for *query* and format the results as prompt context.

    Search is best-effort. If the search backend raises (network failure,
    rate limiting, ...) or returns nothing, an empty string is returned. The
    chatbot treats an empty context as "no search" and answers from the chat
    history alone.

    Args:
        query: The user's message, passed verbatim to the search tool.

    Returns:
        ``"Relevant and recent information:\\n<results>"``, or ``""`` on failure.
    """
    try:
        search_results = search.invoke(query)
    except Exception as exc:  # external service boundary: log and degrade gracefully
        print(f"Search failed, answering without web context: {exc!r}")
        return ""
    print("Search results:", search_results)  # Optional: Debugging output
    if not search_results:
        return ""
    return f"Relevant and recent information:\n{search_results}"
# Main chatbot function
def _search_prompt_ids(user_query: str, search_context: str, history: List[Tuple[str, str]], max_tokens: int = 2500):
    """Tokenize the search-augmented prompt, shortening the search context to fit ``max_tokens``.

    The search context is trimmed rather than the end of the prompt, because
    construct_model_prompt places the user's question last. Truncating the
    tail would drop the question.
    """
    prompt = construct_model_prompt(user_query, search_context, history)
    input_ids = tok(prompt, return_tensors="pt").input_ids
    overflow = input_ids.shape[1] - max_tokens
    if overflow <= 0:
        return input_ids
    context_ids = tok(search_context, add_special_tokens=False).input_ids
    keep = max(len(context_ids) - overflow, 0)
    trimmed_context = tok.decode(context_ids[:keep], skip_special_tokens=True)
    prompt = construct_model_prompt(user_query, trimmed_context, history)
    # Decoding and re-encoding can shift token boundaries by a few tokens. As a
    # last resort, keep the tail of the prompt, which holds the question.
    return tok(prompt, return_tensors="pt").input_ids[:, -max_tokens:]


def _history_prompt_ids(history: List[Tuple[str, str]], max_tokens: int = 2000):
    """Tokenize the chat history, keeping only the latest turn when it exceeds ``max_tokens``."""
    input_ids = convert_history_to_token(history)
    if input_ids.shape[1] > max_tokens:
        input_ids = convert_history_to_token([history[-1]])
    return input_ids


def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
    """Stream the model's reply to the latest user message in ``history``.

    If the message looks like it needs fresh information (should_use_search),
    web-search results are prepended to a plain-text prompt. Otherwise the chat
    history is encoded with the model's template. Overly long inputs are
    shortened for the model only; ``history`` keeps every earlier turn, so the
    UI does not lose them.

    Args:
        history: Conversation as ``(user, assistant)`` pairs; the last pair is
            the pending user message. Its assistant slot is overwritten in
            place while streaming.
        temperature: Sampling temperature; sampling is disabled when it is 0.
        top_p, top_k, repetition_penalty: Forwarded to ``generate``.
        conversation_id: Accepted for interface compatibility; not used.

    Yields:
        ``history`` after each streamed chunk, with the last entry set to
        ``(user_query, text_so_far)``.
    """
    user_query = history[-1][0]
    search_context = fetch_search_results(user_query) if should_use_search(user_query) else ""
    if search_context:
        input_ids = _search_prompt_ids(user_query, search_context, history)
    else:
        input_ids = _history_prompt_ids(history)

    # Configure response streaming
    streamer = TextIteratorStreamer(tok, timeout=4600.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "do_sample": temperature > 0.0,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
        "stopping_criteria": StoppingCriteriaList(stop_tokens) if stop_tokens is not None else None,
    }

    def generate_and_signal_complete():
        try:
            ov_model.generate(**generate_kwargs)
        except Exception as exc:
            # generate() only ends the streamer when it finishes normally. End
            # it here too, so the consumer loop below does not block for the
            # whole streamer timeout after a failure or cancellation.
            streamer.end()
            if isinstance(exc, RuntimeError) and "Infer Request was canceled" in str(exc):
                print("Generation request was canceled.")
            else:
                raise

    worker = Thread(target=generate_and_signal_complete)
    worker.start()
    partial_text = ""
    for new_text in streamer:
        partial_text = text_processor(partial_text, new_text)
        history[-1] = (user_query, partial_text)
        yield history
    worker.join()
def request_cancel():
    """Cancel the inference request currently held by ``ov_model``.

    A cancelled request makes ``generate`` raise a RuntimeError containing
    "Infer Request was canceled". The generation thread in ``bot`` reports
    that error as a cancellation instead of a failure. This function is not
    passed to ``make_demo`` in this file.
    """
    active_request = ov_model.request
    active_request.cancel()
# Gradio setup and launch
# Build the Gradio UI around the streaming ``bot`` callback.
demo = make_demo(
    run_fn=bot,
    title="OpenVINO Search & Reasoning Chatbot",
    language=model_language_value,
)

if __name__ == "__main__":
    # Listen on all interfaces at port 7860 (server_name / server_port below).
    demo.launch(
        debug=True,
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
    )
|