# app.py
import os
from pathlib import Path
import torch
from threading import Event, Thread
from typing import List, Tuple

# Importing necessary packages
from transformers import AutoConfig, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from langchain_community.tools import DuckDuckGoSearchRun
from optimum.intel.openvino import OVModelForCausalLM
import openvino as ov
import openvino.properties as props
import openvino.properties.hint as hints
import openvino.properties.streams as streams

from gradio_helper import make_demo  # UI logic import
from llm_config import SUPPORTED_LLM_MODELS

# Model configuration setup
max_new_tokens = 256
model_language_value = "English"
model_id_value = "qwen2.5-0.5b-instruct"
prepare_int4_model_value = True
enable_awq_value = False
device_value = "CPU"
model_to_run_value = "INT4"
pt_model_id = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]["model_id"]
pt_model_name = model_id_value.split("-")[0]
int4_model_dir = Path(model_id_value) / "INT4_compressed_weights"
int4_weights = int4_model_dir / "openvino_model.bin"
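# NOTE: the INT4-compressed OpenVINO model is assumed to already exist at this path
# (e.g. exported beforehand with `optimum-cli export openvino --weight-format int4`
# or a separate conversion step); this script only loads it.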

model_configuration = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]
model_name = model_configuration["model_id"]
start_message = model_configuration["start_message"]
history_template = model_configuration.get("history_template")
has_chat_template = model_configuration.get("has_chat_template", history_template is None)
current_message_template = model_configuration.get("current_message_template")
stop_tokens = model_configuration.get("stop_tokens")
tokenizer_kwargs = model_configuration.get("tokenizer_kwargs", {})

# Model loading
core = ov.Core()
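# Latency-oriented performance hint with a single inference stream; cache_dir is
# left empty, so no compiled-model cache is written to disk.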
ov_config = {
    hints.performance_mode(): hints.PerformanceMode.LATENCY,
    streams.num(): "1",
    props.cache_dir(): ""
}
tok = AutoTokenizer.from_pretrained(int4_model_dir, trust_remote_code=True)
ov_model = OVModelForCausalLM.from_pretrained(
    int4_model_dir,
    device=device_value,
    ov_config=ov_config,
    config=AutoConfig.from_pretrained(int4_model_dir, trust_remote_code=True),
    trust_remote_code=True,
)

# Define stopping criteria for specific token sequences
class StopOnTokens(StoppingCriteria):
    def __init__(self, token_ids):
        self.token_ids = token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
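        # Stop as soon as the most recently generated token matches any configured stop id.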
        return any(input_ids[0][-1] == stop_id for stop_id in self.token_ids)

if stop_tokens is not None:
    if isinstance(stop_tokens[0], str):
        stop_tokens = tok.convert_tokens_to_ids(stop_tokens)
    stop_tokens = [StopOnTokens(stop_tokens)]

# Helper function for partial text update
def default_partial_text_processor(partial_text: str, new_text: str) -> str:
    return partial_text + new_text

text_processor = model_configuration.get("partial_text_processor", default_partial_text_processor)

# Convert conversation history to tokens based on model template
def convert_history_to_token(history: List[Tuple[str, str]]):
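    # Three cases: baichuan2 builds its prompt from hard-coded special tokens (195/196),
    # models with a chat template go through tok.apply_chat_template, and everything else
    # falls back to the plain-text history/current-message templates.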
    if pt_model_name == "baichuan2":
        system_tokens = tok.encode(start_message)
        history_tokens = []
        for old_query, response in history[:-1]:
            round_tokens = [195] + tok.encode(old_query) + [196] + tok.encode(response)
            history_tokens = round_tokens + history_tokens
        input_tokens = system_tokens + history_tokens + [195] + tok.encode(history[-1][0]) + [196]
        input_token = torch.LongTensor([input_tokens])
    elif history_template is None or has_chat_template:
        messages = [{"role": "system", "content": start_message}]
        for idx, (user_msg, model_msg) in enumerate(history):
            if idx == len(history) - 1 and not model_msg:
                messages.append({"role": "user", "content": user_msg})
                break
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if model_msg:
                messages.append({"role": "assistant", "content": model_msg})
        input_token = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_tensors="pt")
    else:
        text = start_message + "".join(
            [history_template.format(num=num, user=item[0], assistant=item[1]) for num, item in enumerate(history[:-1])]
        )
        text += current_message_template.format(num=len(history) + 1, user=history[-1][0], assistant=history[-1][1])
        input_token = tok(text, return_tensors="pt", **tokenizer_kwargs).input_ids
    return input_token

# Initialize search tool
search = DuckDuckGoSearchRun()
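# DuckDuckGoSearchRun issues live web queries; it requires network access and the
# `duckduckgo-search` package to be installed.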

# Determine if a search is needed based on the query
def should_use_search(query: str) -> bool:
    # Simple keyword heuristic: trigger a web search when the query looks
    # time-sensitive or information-seeking. No classifier is involved.
    search_keywords = [
        "latest", "news", "update", "which", "who", "what", "when", "why", "how", "recent", "current",
        "announcement", "bulletin", "report", "brief", "insight", "disclosure",
        "release", "memo", "headline", "ongoing", "fresh", "upcoming", "immediate",
        "recently", "new", "now", "in-progress", "inquiry", "query", "ask", "investigate",
        "explore", "seek", "clarify", "confirm", "discover", "learn", "describe", "define",
        "illustrate", "outline", "interpret", "expound", "detail", "summarize", "elucidate",
        "break down", "outcome", "effect", "consequence", "finding", "achievement", "conclusion",
        "product", "performance", "resolution",
    ]
    return any(keyword in query.lower() for keyword in search_keywords)

# Construct the prompt with optional search context
def construct_model_prompt(user_query: str, search_context: str, history: List[Tuple[str, str]]) -> str:
    instructions = (
        "Based on the information provided below, deliver an accurate, concise, and easily understandable answer. If relevant information is missing, draw on your general knowledge and mention the absence of specific details."
    )
    prompt = f"{instructions}\n\n{search_context if search_context else ''}\n\n{user_query} ?\n\n"
    return prompt

# Fetch search results for a query
def fetch_search_results(query: str) -> str:
    search_results = search.invoke(query)
    print("Search results:", search_results)  # Optional: Debugging output
    return f"Relevant and recent information:\n{search_results}"

# Main chatbot function
def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
    user_query = history[-1][0]
    search_context = fetch_search_results(user_query) if should_use_search(user_query) else ""
    prompt = construct_model_prompt(user_query, search_context, history)
    if search_context:
        input_ids = tok(prompt, return_tensors="pt", truncation=True, max_length=2500).input_ids
    else:
        input_ids = convert_history_to_token(history)
        # If the accumulated history is too long, keep only the latest turn and
        # re-tokenize so the request stays within the token limit.
        if input_ids.shape[1] > 2000:
            history = [history[-1]]
            input_ids = convert_history_to_token(history)

    # Configure response streaming
    streamer = TextIteratorStreamer(tok, timeout=4600.0, skip_prompt=True, skip_special_tokens=True)
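    # Sampling is enabled only when temperature > 0; otherwise decoding is greedy.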
    generate_kwargs = {
        "input_ids": input_ids,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "do_sample": temperature > 0.0,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
        "stopping_criteria": StoppingCriteriaList(stop_tokens) if stop_tokens is not None else None,
    }

    # Signal completion
    stream_complete = Event()
    def generate_and_signal_complete():
        try:
            ov_model.generate(**generate_kwargs)
        except RuntimeError as e:
            # Check if the error message indicates the request was canceled
            if "Infer Request was canceled" in str(e):
                print("Generation request was canceled.")
            else:
                # If it's a different RuntimeError, re-raise it
                raise e
        finally:
            # Signal completion of the stream
            stream_complete.set()

    t1 = Thread(target=generate_and_signal_complete)
    t1.start()

    partial_text = ""
    for new_text in streamer:
        partial_text = text_processor(partial_text, new_text)
        history[-1] = (user_query, partial_text)
        yield history

def request_cancel():
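    # Cancel the in-flight OpenVINO infer request; the generation thread catches the resulting RuntimeError.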
    ov_model.request.cancel()

# Gradio setup and launch
demo = make_demo(run_fn=bot, title="OpenVINO Search & Reasoning Chatbot", language=model_language_value)
if __name__ == "__main__":
    demo.launch(debug=True, share=True, server_name="0.0.0.0", server_port=7860)