lightmate committed on
Commit
657585b
1 Parent(s): 2390a8c

Upload 5 files

Files changed (5)
  1. app.py +108 -0
  2. gradio_helper.py +175 -0
  3. llm_config.py +785 -0
  4. notebook_utils.py +715 -0
  5. requirements.txt +14 -0
app.py CHANGED
@@ -0,0 +1,108 @@
1
+ import os
2
+ import torch
3
+ from threading import Thread
+ from transformers import AutoTokenizer, AutoConfig, TextIteratorStreamer
4
+ from optimum.intel.openvino import OVModelForCausalLM
5
+ import openvino as ov
6
+ import gradio as gr
7
+ from gradio_helper import make_demo
8
+ from llm_config import SUPPORTED_LLM_MODELS
9
+ from pathlib import Path
10
+
11
+ # Define model configuration
12
+ model_language = "en" # Example: set to English
13
+ model_id = "qwen2.5-0.5b-instruct" # Example model ID
14
+
15
+ # Define model directories
16
+ pt_model_id = SUPPORTED_LLM_MODELS[model_language][model_id]["model_id"]
17
+ int4_model_dir = Path(model_id) / "INT4_compressed_weights"
18
+
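+ # Note: for deployment the compressed model is expected to already exist locally, e.g.
+ # qwen2.5-0.5b-instruct/INT4_compressed_weights/openvino_model.xml (plus .bin and tokenizer files)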
19
+ # The tokenizer is loaded below, after the INT4 weights check, so a fresh environment
+ # does not fail before the compressed model files exist
21
+
22
+ # Ensure INT4 weights exist; if not, attempt conversion (locally before deployment)
23
+ def check_and_convert_model():
24
+ if not (int4_model_dir / "openvino_model.xml").exists():
25
+ print("INT4 model weights not found. Attempting compression...")
26
+ convert_to_int4()
27
+
28
+ def convert_to_int4():
29
+ """
30
+ Converts a model to INT4 precision using the optimum-cli tool.
31
+ This function should only be run locally or in an environment that supports shell commands.
32
+ """
33
+ # Define compression parameters
34
+ compression_configs = {
35
+ "qwen2.5-0.5b-instruct": {"sym": True, "group_size": 128, "ratio": 1.0},
36
+ "default": {"sym": False, "group_size": 128, "ratio": 0.8},
37
+ }
38
+
39
+ model_compression_params = compression_configs.get(model_id, compression_configs["default"])
40
+
41
+ # Check if the INT4 model already exists
42
+ if (int4_model_dir / "openvino_model.xml").exists():
43
+ print("INT4 model already exists.")
44
+ return # Exit if the model is already converted
45
+
46
+ # Run model compression using `optimum-cli`
47
+ export_command_base = f"optimum-cli export openvino --model {pt_model_id} --task text-generation-with-past --weight-format int4"
48
+ int4_compression_args = f" --group-size {model_compression_params['group_size']} --ratio {model_compression_params['ratio']}"
49
+ if model_compression_params["sym"]:
50
+ int4_compression_args += " --sym"
51
+
52
+ # You can add other custom compression arguments here (like AWQ)
53
+ export_command = f"{export_command_base}{int4_compression_args} {int4_model_dir}"
54
+ print(f"Running compression command: {export_command}")
55
+
56
+ # Execute the export command (this is typically done locally, not in Hugging Face Spaces)
57
+ # For deployment, the model needs to be pre-compressed and uploaded
58
+ os.system(export_command)
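+ # For the model configured above, the assembled command would look roughly like this
+ # (illustrative, using the qwen2.5-0.5b-instruct settings defined earlier):
+ # optimum-cli export openvino --model Qwen/Qwen2.5-0.5B-Instruct --task text-generation-with-past \
+ # --weight-format int4 --group-size 128 --ratio 1.0 --sym qwen2.5-0.5b-instruct/INT4_compressed_weights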
59
+
60
+ # Check if the INT4 model exists or needs conversion
61
+ check_and_convert_model()
+
+ # Load tokenizer (the export step also writes tokenizer files into the INT4 directory)
+ tok = AutoTokenizer.from_pretrained(int4_model_dir, trust_remote_code=True)
62
+
63
+ # Initialize OpenVINO model
64
+ core = ov.Core()
65
+ ov_model = OVModelForCausalLM.from_pretrained(
66
+ str(int4_model_dir),
67
+ device="CPU", # Adjust device as needed (e.g., "GPU" or "CPU")
68
+ config=AutoConfig.from_pretrained(str(int4_model_dir), trust_remote_code=True),
69
+ trust_remote_code=True,
70
+ )
71
+
72
+ def convert_history_to_token(history):
73
+ """
74
+ Convert the history of the conversation into tokens for the model.
75
+ """
76
+ input_ids = tok.encode(history[-1][0]) # Simplified example: only the latest user message is tokenized (no chat template or prior turns)
77
+ return torch.LongTensor([input_ids])
78
+
79
+ def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
80
+ """
81
+ Bot logic to process conversation history and generate responses.
82
+ """
83
+ input_ids = convert_history_to_token(history)
84
+ streamer = TextIteratorStreamer(tok, timeout=3600.0, skip_prompt=True, skip_special_tokens=True)
85
+ generate_kwargs = dict(
86
+ input_ids=input_ids,
87
+ max_new_tokens=256,
88
+ temperature=temperature,
89
+ do_sample=temperature > 0.0,
90
+ top_p=top_p,
91
+ top_k=top_k,
92
+ repetition_penalty=repetition_penalty,
93
+ streamer=streamer,
94
+ )
95
+
96
+ # Run generation in a background thread so the streamer can yield tokens as they are produced
97
+ Thread(target=ov_model.generate, kwargs=generate_kwargs).start()
98
+
99
+ # Stream and update history with generated response
100
+ partial_text = ""
101
+ for new_text in streamer:
102
+ partial_text += new_text
103
+ history[-1][1] = partial_text
104
+ yield history
105
+
106
+ # Gradio interface setup
107
+ demo = make_demo(run_fn=bot, stop_fn=None, title="OpenVINO Chatbot", language="English")
108
+ demo.launch(debug=True, share=True)
gradio_helper.py ADDED
@@ -0,0 +1,175 @@
1
+ from typing import Callable, Literal
2
+ import gradio as gr
3
+ from uuid import uuid4
4
+
5
+
6
+ chinese_examples = [
7
+ ["你好!"],
8
+ ["你是谁?"],
9
+ ["请介绍一下上海"],
10
+ ["请介绍一下英特尔公司"],
11
+ ["晚上睡不着怎么办?"],
12
+ ["给我讲一个年轻人奋斗创业最终取得成功的故事。"],
13
+ ["给这个故事起一个标题。"],
14
+ ]
15
+
16
+ english_examples = [
17
+ ["Hello there! How are you doing?"],
18
+ ["What is OpenVINO?"],
19
+ ["Who are you?"],
20
+ ["Can you explain to me briefly what is Python programming language?"],
21
+ ["Explain the plot of Cinderella in a sentence."],
22
+ ["What are some common mistakes to avoid when writing code?"],
23
+ ["Write a 100-word blog post on “Benefits of Artificial Intelligence and OpenVINO“"],
24
+ ]
25
+
26
+ japanese_examples = [
27
+ ["こんにちは!調子はどうですか?"],
28
+ ["OpenVINOとは何ですか?"],
29
+ ["あなたは誰ですか?"],
30
+ ["Pythonプログラミング言語とは何か簡単に説明してもらえますか?"],
31
+ ["シンデレラのあらすじを一文で説明してください。"],
32
+ ["コードを書くときに避けるべきよくある間違いは何ですか?"],
33
+ ["人工知能と「OpenVINOの利点」について100語程度のブログ記事を書いてください。"],
34
+ ]
35
+
36
+
37
+ def get_uuid():
38
+ """
39
+ Return a universally unique identifier for the conversation thread.
40
+ """
41
+ return str(uuid4())
42
+
43
+
44
+ def handle_user_message(message, history):
45
+ """
46
+ callback function for updating user messages in interface on submit button click
47
+
48
+ Params:
49
+ message: current message
50
+ history: conversation history
51
+ Returns:
52
+ Cleared message box value and the updated conversation history
53
+ """
54
+ # Append the user's message to the conversation history
55
+ return "", history + [[message, ""]]
56
+
57
+
58
+ def make_demo(
59
+ run_fn: Callable,
60
+ stop_fn: Callable,
61
+ title: str = "OpenVINO Chatbot",
62
+ language: Literal["English", "Chinese", "Japanese"] = "English"
63
+ ):
64
+ # Define examples based on the selected language
65
+ examples = (
66
+ chinese_examples if language == "Chinese"
67
+ else japanese_examples if language == "Japanese"
68
+ else english_examples
69
+ )
70
+
71
+ with gr.Blocks(
72
+ theme=gr.themes.Soft(),
73
+ css=".disclaimer {font-variant-caps: all-small-caps;}"
74
+ ) as demo:
75
+ conversation_id = gr.State(get_uuid) # a fresh unique id is generated for each session
76
+ gr.Markdown(f"<h1><center>{title}</center></h1>")
77
+ chatbot = gr.Chatbot(height=500)
78
+
79
+ # User message input
80
+ with gr.Row():
81
+ with gr.Column():
82
+ msg = gr.Textbox(
83
+ label="Chat Message Box",
84
+ placeholder="Chat Message Box",
85
+ show_label=False,
86
+ container=False,
87
+ )
88
+ with gr.Column():
89
+ submit = gr.Button("Submit")
90
+ stop = gr.Button("Stop")
91
+ clear = gr.Button("Clear")
92
+
93
+ # Advanced options for the chat
94
+ with gr.Row():
95
+ with gr.Accordion("Advanced Options:", open=False):
96
+ temperature = gr.Slider(
97
+ label="Temperature",
98
+ value=0.1,
99
+ minimum=0.0,
100
+ maximum=1.0,
101
+ step=0.1,
102
+ interactive=True,
103
+ info="Higher values produce more diverse outputs",
104
+ )
105
+ top_p = gr.Slider(
106
+ label="Top-p (nucleus sampling)",
107
+ value=1.0,
108
+ minimum=0.0,
109
+ maximum=1.0,
110
+ step=0.01,
111
+ interactive=True,
112
+ info=("Sample from the smallest possible set of tokens whose cumulative probability exceeds top_p. "
113
+ "Set to 1 to disable and sample from all tokens."),
114
+ )
115
+ top_k = gr.Slider(
116
+ label="Top-k",
117
+ value=50,
118
+ minimum=0,
119
+ maximum=200,
120
+ step=1,
121
+ interactive=True,
122
+ info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
123
+ )
124
+ repetition_penalty = gr.Slider(
125
+ label="Repetition Penalty",
126
+ value=1.1,
127
+ minimum=1.0,
128
+ maximum=2.0,
129
+ step=0.1,
130
+ interactive=True,
131
+ info="Penalize repetition — 1.0 to disable.",
132
+ )
133
+
134
+ # Example messages
135
+ gr.Examples(examples, inputs=msg, label="Click on any example and press the 'Submit' button")
136
+
137
+ # Submit message event
138
+ submit_event = msg.submit(
139
+ fn=handle_user_message,
140
+ inputs=[msg, chatbot],
141
+ outputs=[msg, chatbot],
142
+ queue=False,
143
+ ).then(
144
+ fn=run_fn,
145
+ inputs=[chatbot, temperature, top_p, top_k, repetition_penalty, conversation_id],
146
+ outputs=chatbot,
147
+ queue=True,
148
+ )
149
+
150
+ # Submit button click event
151
+ submit.click(
152
+ fn=handle_user_message,
153
+ inputs=[msg, chatbot],
154
+ outputs=[msg, chatbot],
155
+ queue=False,
156
+ ).then(
157
+ fn=run_fn,
158
+ inputs=[chatbot, temperature, top_p, top_k, repetition_penalty, conversation_id],
159
+ outputs=chatbot,
160
+ queue=True,
161
+ )
162
+
163
+ # Stop button functionality
164
+ stop.click(
165
+ fn=stop_fn,
166
+ inputs=None,
167
+ outputs=None,
168
+ cancels=[submit_event], # Cancels the submission event
169
+ queue=False,
170
+ )
171
+
172
+ # Clear chat button functionality
173
+ clear.click(lambda: None, None, chatbot, queue=False)
174
+
175
+ return demo
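+
+ # Usage sketch (mirrors app.py; `bot` is expected to be a generator that yields updated history):
+ # demo = make_demo(run_fn=bot, stop_fn=None, title="OpenVINO Chatbot", language="English")
+ # demo.launch()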
llm_config.py ADDED
@@ -0,0 +1,785 @@
1
+ DEFAULT_SYSTEM_PROMPT = """\
2
+ You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
3
+ If a question does not make any sense or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
4
+ """
5
+
6
+ DEFAULT_SYSTEM_PROMPT_CHINESE = """\
7
+ 你是一个乐于助人、尊重他人以及诚实可靠的助手。在安全的情况下,始终尽可能有帮助地回答。 您的回答不应包含任何有害、不道德、种族主义、性别歧视、有毒、危险或非法的内容。请确保您的回答在社会上是公正的和积极的。
8
+ 如果一个问题没有任何意义或与事实不符,请解释原因,而不是回答错误的问题。如果您不知道问题的答案,请不要分享虚假信息。另外,答案请使用中文。\
9
+ """
10
+
11
+ DEFAULT_SYSTEM_PROMPT_JAPANESE = """\
12
+ あなたは親切で、礼儀正しく、誠実なアシスタントです。 常に安全を保ちながら、できるだけ役立つように答えてください。 回答には、有害、非倫理的、人種差別的、性差別的、有毒、危険、または違法なコンテンツを含めてはいけません。 回答は社会的に偏見がなく、本質的に前向きなものであることを確認してください。
13
+ 質問が意味をなさない場合、または事実に一貫性がない場合は、正しくないことに答えるのではなく、その理由を説明してください。 質問の答えがわからない場合は、誤った情報を共有しないでください。\
14
+ """
15
+
16
+ DEFAULT_RAG_PROMPT = """\
17
+ You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.\
18
+ """
19
+
20
+ DEFAULT_RAG_PROMPT_CHINESE = """\
21
+ 基于以下已知信息,请简洁并专业地回答用户的问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息"。不允许在答案中添加编造成分。另外,答案请使用中文。\
22
+ """
23
+
24
+
25
+ def red_pijama_partial_text_processor(partial_text, new_text):
26
+ if new_text == "<":
27
+ return partial_text
28
+
29
+ partial_text += new_text
30
+ return partial_text.split("<bot>:")[-1]
31
+
32
+
33
+ def llama_partial_text_processor(partial_text, new_text):
34
+ new_text = new_text.replace("[INST]", "").replace("[/INST]", "")
35
+ partial_text += new_text
36
+ return partial_text
37
+
38
+
39
+ def chatglm_partial_text_processor(partial_text, new_text):
40
+ new_text = new_text.strip()
41
+ new_text = new_text.replace("[[训练时间]]", "2023年")
42
+ partial_text += new_text
43
+ return partial_text
44
+
45
+
46
+ def youri_partial_text_processor(partial_text, new_text):
47
+ new_text = new_text.replace("システム:", "")
48
+ partial_text += new_text
49
+ return partial_text
50
+
51
+
52
+ def internlm_partial_text_processor(partial_text, new_text):
53
+ partial_text += new_text
54
+ return partial_text.split("<|im_end|>")[0]
55
+
56
+
57
+ def phi_completion_to_prompt(completion):
58
+ return f"<|system|><|end|><|user|>{completion}<|end|><|assistant|>\n"
59
+
60
+
61
+ def llama3_completion_to_prompt(completion):
62
+ return f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{completion}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
63
+
64
+
65
+ def qwen_completion_to_prompt(completion):
66
+ return f"<|im_start|>system\n<|im_end|>\n<|im_start|>user\n{completion}<|im_end|>\n<|im_start|>assistant\n"
67
+
68
+
69
+ SUPPORTED_LLM_MODELS = {
70
+ "English": {
71
+ "qwen2.5-0.5b-instruct": {
72
+ "model_id": "Qwen/Qwen2.5-0.5B-Instruct",
73
+ "remote_code": False,
74
+ "start_message": DEFAULT_SYSTEM_PROMPT,
75
+ "stop_tokens": ["<|im_end|>", "<|endoftext|>"],
76
+ "completion_to_prompt": qwen_completion_to_prompt,
77
+ },
78
+ "tiny-llama-1b-chat": {
79
+ "model_id": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
80
+ "remote_code": False,
81
+ "start_message": f"<|system|>\n{DEFAULT_SYSTEM_PROMPT}</s>\n",
82
+ "history_template": "<|user|>\n{user}</s> \n<|assistant|>\n{assistant}</s> \n",
83
+ "current_message_template": "<|user|>\n{user}</s> \n<|assistant|>\n{assistant}",
84
+ "rag_prompt_template": f"""<|system|> {DEFAULT_RAG_PROMPT }</s>"""
85
+ + """
86
+ <|user|>
87
+ Question: {input}
88
+ Context: {context}
89
+ Answer: </s>
90
+ <|assistant|>""",
91
+ },
92
+ "llama-3.2-1b-instruct": {
93
+ "model_id": "meta-llama/Llama-3.2-1B-Instruct",
95
+ "stop_tokens": ["<|eot_id|>"],
96
+ "has_chat_template": True,
97
+ "start_message": " <|start_header_id|>system<|end_header_id|>\n\n" + DEFAULT_SYSTEM_PROMPT + "<|eot_id|>",
98
+ "history_template": "<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{assistant}<|eot_id|>",
99
+ "current_message_template": "<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{assistant}",
100
+ "rag_prompt_template": f"<|start_header_id|>system<|end_header_id|>\n\n{DEFAULT_RAG_PROMPT}<|eot_id|>"
101
+ + """<|start_header_id|>user<|end_header_id|>
102
+
103
+
104
+ Question: {input}
105
+ Context: {context}
106
+ Answer:<|eot_id|><|start_header_id|>assistant<|end_header_id|>
107
+
108
+
109
+ """,
110
+ "completion_to_prompt": llama3_completion_to_prompt,
111
+ },
112
+ "llama-3.2-3b-instruct": {
113
+ "model_id": "meta-llama/Llama-3.2-3B-Instruct",
115
+ "stop_tokens": ["<|eot_id|>"],
116
+ "has_chat_template": True,
117
+ "start_message": " <|start_header_id|>system<|end_header_id|>\n\n" + DEFAULT_SYSTEM_PROMPT + "<|eot_id|>",
118
+ "history_template": "<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{assistant}<|eot_id|>",
119
+ "current_message_template": "<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{assistant}",
120
+ "rag_prompt_template": f"<|start_header_id|>system<|end_header_id|>\n\n{DEFAULT_RAG_PROMPT}<|eot_id|>"
121
+ + """<|start_header_id|>user<|end_header_id|>
122
+
123
+
124
+ Question: {input}
125
+ Context: {context}
126
+ Answer:<|eot_id|><|start_header_id|>assistant<|end_header_id|>
127
+
128
+
129
+ """,
130
+ "completion_to_prompt": llama3_completion_to_prompt,
131
+ },
132
+ "qwen2.5-1.5b-instruct": {
133
+ "model_id": "Qwen/Qwen2.5-1.5B-Instruct",
134
+ "remote_code": False,
135
+ "start_message": DEFAULT_SYSTEM_PROMPT,
136
+ "stop_tokens": ["<|im_end|>", "<|endoftext|>"],
137
+ "completion_to_prompt": qwen_completion_to_prompt,
138
+ },
139
+ "gemma-2b-it": {
140
+ "model_id": "google/gemma-2b-it",
141
+ "remote_code": False,
142
+ "start_message": DEFAULT_SYSTEM_PROMPT + ", ",
143
+ "history_template": "<start_of_turn>user{user}<end_of_turn><start_of_turn>model{assistant}<end_of_turn>",
144
+ "current_message_template": "<start_of_turn>user{user}<end_of_turn><start_of_turn>model{assistant}",
145
+ "rag_prompt_template": f"""{DEFAULT_RAG_PROMPT},"""
146
+ + """<start_of_turn>user{input}<end_of_turn><start_of_turn>context{context}<end_of_turn><start_of_turn>model""",
147
+ },
148
+ "gemma-2-2b-it": {
149
+ "model_id": "google/gemma-2-2b-it",
150
+ "remote_code": False,
151
+ "start_message": DEFAULT_SYSTEM_PROMPT + ", ",
152
+ "history_template": "<start_of_turn>user{user}<end_of_turn><start_of_turn>model{assistant}<end_of_turn>",
153
+ "current_message_template": "<start_of_turn>user{user}<end_of_turn><start_of_turn>model{assistant}",
154
+ "rag_prompt_template": f"""{DEFAULT_RAG_PROMPT},"""
155
+ + """<start_of_turn>user{input}<end_of_turn><start_of_turn>context{context}<end_of_turn><start_of_turn>model""",
156
+ },
157
+ "red-pajama-3b-chat": {
158
+ "model_id": "togethercomputer/RedPajama-INCITE-Chat-3B-v1",
159
+ "remote_code": False,
160
+ "start_message": "",
161
+ "history_template": "\n<human>:{user}\n<bot>:{assistant}",
162
+ "stop_tokens": [29, 0],
163
+ "partial_text_processor": red_pijama_partial_text_processor,
164
+ "current_message_template": "\n<human>:{user}\n<bot>:{assistant}",
165
+ "rag_prompt_template": f"""{DEFAULT_RAG_PROMPT }"""
166
+ + """
167
+ <human>: Question: {input}
168
+ Context: {context}
169
+ Answer: <bot>""",
170
+ },
171
+ "qwen2.5-3b-instruct": {
172
+ "model_id": "Qwen/Qwen2.5-3B-Instruct",
173
+ "remote_code": False,
174
+ "start_message": DEFAULT_SYSTEM_PROMPT + ", ",
175
+ "rag_prompt_template": f"""<|im_start|>system
176
+ {DEFAULT_RAG_PROMPT }<|im_end|>"""
177
+ + """
178
+ <|im_start|>user
179
+ Question: {input}
180
+ Context: {context}
181
+ Answer: <|im_end|>
182
+ <|im_start|>assistant
183
+ """,
184
+ "completion_to_prompt": qwen_completion_to_prompt,
185
+ },
186
+ "qwen2.5-7b-instruct": {
187
+ "model_id": "Qwen/Qwen2.5-7B-Instruct",
188
+ "remote_code": False,
189
+ "start_message": DEFAULT_SYSTEM_PROMPT + ", ",
190
+ "rag_prompt_template": f"""<|im_start|>system
191
+ {DEFAULT_RAG_PROMPT }<|im_end|>"""
192
+ + """
193
+ <|im_start|>user
194
+ Question: {input}
195
+ Context: {context}
196
+ Answer: <|im_end|>
197
+ <|im_start|>assistant
198
+ """,
199
+ "completion_to_prompt": qwen_completion_to_prompt,
200
+ },
201
+ "gemma-7b-it": {
202
+ "model_id": "google/gemma-7b-it",
203
+ "remote_code": False,
204
+ "start_message": DEFAULT_SYSTEM_PROMPT + ", ",
205
+ "history_template": "<start_of_turn>user{user}<end_of_turn><start_of_turn>model{assistant}<end_of_turn>",
206
+ "current_message_template": "<start_of_turn>user{user}<end_of_turn><start_of_turn>model{assistant}",
207
+ "rag_prompt_template": f"""{DEFAULT_RAG_PROMPT},"""
208
+ + """<start_of_turn>user{input}<end_of_turn><start_of_turn>context{context}<end_of_turn><start_of_turn>model""",
209
+ },
210
+ "gemma-2-9b-it": {
211
+ "model_id": "google/gemma-2-9b-it",
212
+ "remote_code": False,
213
+ "start_message": DEFAULT_SYSTEM_PROMPT + ", ",
214
+ "history_template": "<start_of_turn>user{user}<end_of_turn><start_of_turn>model{assistant}<end_of_turn>",
215
+ "current_message_template": "<start_of_turn>user{user}<end_of_turn><start_of_turn>model{assistant}",
216
+ "rag_prompt_template": f"""{DEFAULT_RAG_PROMPT},"""
217
+ + """<start_of_turn>user{input}<end_of_turn><start_of_turn>context{context}<end_of_turn><start_of_turn>model""",
218
+ },
219
+ "llama-2-chat-7b": {
220
+ "model_id": "meta-llama/Llama-2-7b-chat-hf",
221
+ "remote_code": False,
222
+ "start_message": f"<s>[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT }\n<</SYS>>\n\n",
223
+ "history_template": "{user}[/INST]{assistant}</s><s>[INST]",
224
+ "current_message_template": "{user} [/INST]{assistant}",
225
+ "tokenizer_kwargs": {"add_special_tokens": False},
226
+ "partial_text_processor": llama_partial_text_processor,
227
+ "rag_prompt_template": f"""[INST]Human: <<SYS>> {DEFAULT_RAG_PROMPT }<</SYS>>"""
228
+ + """
229
+ Question: {input}
230
+ Context: {context}
231
+ Answer: [/INST]""",
232
+ },
233
+ "llama-3-8b-instruct": {
234
+ "model_id": "meta-llama/Meta-Llama-3-8B-Instruct",
235
+ "remote_code": False,
237
+ "stop_tokens": ["<|eot_id|>", "<|end_of_text|>"],
238
+ "has_chat_template": True,
239
+ "start_message": " <|start_header_id|>system<|end_header_id|>\n\n" + DEFAULT_SYSTEM_PROMPT + "<|eot_id|>",
240
+ "history_template": "<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{assistant}<|eot_id|>",
241
+ "current_message_template": "<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{assistant}",
242
+ "rag_prompt_template": f"<|start_header_id|>system<|end_header_id|>\n\n{DEFAULT_RAG_PROMPT}<|eot_id|>"
243
+ + """<|start_header_id|>user<|end_header_id|>
244
+
245
+
246
+ Question: {input}
247
+ Context: {context}
248
+ Answer:<|eot_id|><|start_header_id|>assistant<|end_header_id|>
249
+
250
+
251
+ """,
252
+ "completion_to_prompt": llama3_completion_to_prompt,
253
+ },
254
+ "llama-3.1-8b-instruct": {
255
+ "model_id": "meta-llama/Meta-Llama-3.1-8B-Instruct",
256
+ "remote_code": False,
258
+ "stop_tokens": ["<|eot_id|>", "<|end_of_text|>"],
259
+ "has_chat_template": True,
260
+ "start_message": " <|start_header_id|>system<|end_header_id|>\n\n" + DEFAULT_SYSTEM_PROMPT + "<|eot_id|>",
261
+ "history_template": "<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{assistant}<|eot_id|>",
262
+ "current_message_template": "<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{assistant}",
263
+ "rag_prompt_template": f"<|start_header_id|>system<|end_header_id|>\n\n{DEFAULT_RAG_PROMPT}<|eot_id|>"
264
+ + """<|start_header_id|>user<|end_header_id|>
265
+
266
+
267
+ Question: {input}
268
+ Context: {context}
269
+ Answer:<|eot_id|><|start_header_id|>assistant<|end_header_id|>
270
+
271
+
272
+ """,
273
+ "completion_to_prompt": llama3_completion_to_prompt,
274
+ },
275
+ "mistral-7b-instruct": {
276
+ "model_id": "mistralai/Mistral-7B-Instruct-v0.1",
277
+ "remote_code": False,
278
+ "start_message": f"<s>[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT }\n<</SYS>>\n\n",
279
+ "history_template": "{user}[/INST]{assistant}</s><s>[INST]",
280
+ "current_message_template": "{user} [/INST]{assistant}",
281
+ "tokenizer_kwargs": {"add_special_tokens": False},
282
+ "partial_text_processor": llama_partial_text_processor,
283
+ "rag_prompt_template": f"""<s> [INST] {DEFAULT_RAG_PROMPT } [/INST] </s>"""
284
+ + """
285
+ [INST] Question: {input}
286
+ Context: {context}
287
+ Answer: [/INST]""",
288
+ },
289
+ "zephyr-7b-beta": {
290
+ "model_id": "HuggingFaceH4/zephyr-7b-beta",
291
+ "remote_code": False,
292
+ "start_message": f"<|system|>\n{DEFAULT_SYSTEM_PROMPT}</s>\n",
293
+ "history_template": "<|user|>\n{user}</s> \n<|assistant|>\n{assistant}</s> \n",
294
+ "current_message_template": "<|user|>\n{user}</s> \n<|assistant|>\n{assistant}",
295
+ "rag_prompt_template": f"""<|system|> {DEFAULT_RAG_PROMPT }</s>"""
296
+ + """
297
+ <|user|>
298
+ Question: {input}
299
+ Context: {context}
300
+ Answer: </s>
301
+ <|assistant|>""",
302
+ },
303
+ "notus-7b-v1": {
304
+ "model_id": "argilla/notus-7b-v1",
305
+ "remote_code": False,
306
+ "start_message": f"<|system|>\n{DEFAULT_SYSTEM_PROMPT}</s>\n",
307
+ "history_template": "<|user|>\n{user}</s> \n<|assistant|>\n{assistant}</s> \n",
308
+ "current_message_template": "<|user|>\n{user}</s> \n<|assistant|>\n{assistant}",
309
+ "rag_prompt_template": f"""<|system|> {DEFAULT_RAG_PROMPT }</s>"""
310
+ + """
311
+ <|user|>
312
+ Question: {input}
313
+ Context: {context}
314
+ Answer: </s>
315
+ <|assistant|>""",
316
+ },
317
+ "neural-chat-7b-v3-3": {
318
+ "model_id": "Intel/neural-chat-7b-v3-3",
319
+ "remote_code": False,
320
+ "start_message": f"<s>[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT }\n<</SYS>>\n\n",
321
+ "history_template": "{user}[/INST]{assistant}</s><s>[INST]",
322
+ "current_message_template": "{user} [/INST]{assistant}",
323
+ "tokenizer_kwargs": {"add_special_tokens": False},
324
+ "partial_text_processor": llama_partial_text_processor,
325
+ "rag_prompt_template": f"""<s> [INST] {DEFAULT_RAG_PROMPT } [/INST] </s>"""
326
+ + """
327
+ [INST] Question: {input}
328
+ Context: {context}
329
+ Answer: [/INST]""",
330
+ },
331
+ "phi-3-mini-instruct": {
332
+ "model_id": "microsoft/Phi-3-mini-4k-instruct",
333
+ "remote_code": True,
334
+ "start_message": "<|system|>\n{DEFAULT_SYSTEM_PROMPT}<|end|>\n",
335
+ "history_template": "<|user|>\n{user}<|end|> \n<|assistant|>\n{assistant}<|end|>\n",
336
+ "current_message_template": "<|user|>\n{user}<|end|> \n<|assistant|>\n{assistant}",
337
+ "stop_tokens": ["<|end|>"],
338
+ "rag_prompt_template": f"""<|system|> {DEFAULT_RAG_PROMPT }<|end|>"""
339
+ + """
340
+ <|user|>
341
+ Question: {input}
342
+ Context: {context}
343
+ Answer: <|end|>
344
+ <|assistant|>""",
345
+ "completion_to_prompt": phi_completion_to_prompt,
346
+ },
347
+ "phi-3.5-mini-instruct": {
348
+ "model_id": "microsoft/Phi-3.5-mini-instruct",
349
+ "remote_code": True,
350
+ "start_message": "<|system|>\n{DEFAULT_SYSTEM_PROMPT}<|end|>\n",
351
+ "history_template": "<|user|>\n{user}<|end|> \n<|assistant|>\n{assistant}<|end|>\n",
352
+ "current_message_template": "<|user|>\n{user}<|end|> \n<|assistant|>\n{assistant}",
353
+ "stop_tokens": ["<|end|>"],
354
+ "rag_prompt_template": f"""<|system|> {DEFAULT_RAG_PROMPT }<|end|>"""
355
+ + """
356
+ <|user|>
357
+ Question: {input}
358
+ Context: {context}
359
+ Answer: <|end|>
360
+ <|assistant|>""",
361
+ "completion_to_prompt": phi_completion_to_prompt,
362
+ },
363
+ "qwen2.5-14b-instruct": {
364
+ "model_id": "Qwen/Qwen2.5-14B-Instruct",
365
+ "remote_code": False,
366
+ "start_message": DEFAULT_SYSTEM_PROMPT + ", ",
367
+ "rag_prompt_template": f"""<|im_start|>system
368
+ {DEFAULT_RAG_PROMPT }<|im_end|>"""
369
+ + """
370
+ <|im_start|>user
371
+ Question: {input}
372
+ Context: {context}
373
+ Answer: <|im_end|>
374
+ <|im_start|>assistant
375
+ """,
376
+ "completion_to_prompt": qwen_completion_to_prompt,
377
+ },
378
+ },
379
+ "Chinese": {
380
+ "qwen2.5-0.5b-instruct": {
381
+ "model_id": "Qwen/Qwen2.5-0.5B-Instruct",
382
+ "remote_code": False,
383
+ "start_message": DEFAULT_SYSTEM_PROMPT_CHINESE,
384
+ "stop_tokens": ["<|im_end|>", "<|endoftext|>"],
385
+ "completion_to_prompt": qwen_completion_to_prompt,
386
+ },
387
+ "qwen2.5-1.5b-instruct": {
388
+ "model_id": "Qwen/Qwen2.5-1.5B-Instruct",
389
+ "remote_code": False,
390
+ "start_message": DEFAULT_SYSTEM_PROMPT_CHINESE,
391
+ "stop_tokens": ["<|im_end|>", "<|endoftext|>"],
392
+ "completion_to_prompt": qwen_completion_to_prompt,
393
+ },
394
+ "qwen2.5-3b-instruct": {
395
+ "model_id": "Qwen/Qwen2.5-3B-Instruct",
396
+ "remote_code": False,
397
+ "start_message": DEFAULT_SYSTEM_PROMPT_CHINESE,
398
+ "stop_tokens": ["<|im_end|>", "<|endoftext|>"],
399
+ "completion_to_prompt": qwen_completion_to_prompt,
400
+ },
401
+ "qwen2.5-7b-instruct": {
402
+ "model_id": "Qwen/Qwen2.5-7B-Instruct",
403
+ "remote_code": False,
404
+ "stop_tokens": ["<|im_end|>", "<|endoftext|>"],
405
+ "start_message": DEFAULT_SYSTEM_PROMPT_CHINESE,
406
+ "rag_prompt_template": f"""<|im_start|>system
407
+ {DEFAULT_RAG_PROMPT_CHINESE }<|im_end|>"""
408
+ + """
409
+ <|im_start|>user
410
+ 问题: {input}
411
+ 已知内容: {context}
412
+ 回答: <|im_end|>
413
+ <|im_start|>assistant
414
+ """,
415
+ "completion_to_prompt": qwen_completion_to_prompt,
416
+ },
417
+ "qwen2.5-14b-instruct": {
418
+ "model_id": "Qwen/Qwen2.5-14B-Instruct",
419
+ "remote_code": False,
420
+ "stop_tokens": ["<|im_end|>", "<|endoftext|>"],
421
+ "start_message": DEFAULT_SYSTEM_PROMPT_CHINESE,
422
+ "rag_prompt_template": f"""<|im_start|>system
423
+ {DEFAULT_RAG_PROMPT_CHINESE }<|im_end|>"""
424
+ + """
425
+ <|im_start|>user
426
+ 问题: {input}
427
+ 已知内容: {context}
428
+ 回答: <|im_end|>
429
+ <|im_start|>assistant
430
+ """,
431
+ "completion_to_prompt": qwen_completion_to_prompt,
432
+ },
433
+ "qwen-7b-chat": {
434
+ "model_id": "Qwen/Qwen-7B-Chat",
435
+ "remote_code": True,
436
+ "start_message": f"<|im_start|>system\n {DEFAULT_SYSTEM_PROMPT_CHINESE }<|im_end|>",
437
+ "history_template": "<|im_start|>user\n{user}<im_end><|im_start|>assistant\n{assistant}<|im_end|>",
438
+ "current_message_template": '"<|im_start|>user\n{user}<im_end><|im_start|>assistant\n{assistant}',
439
+ "stop_tokens": ["<|im_end|>", "<|endoftext|>"],
440
+ "revision": "2abd8e5777bb4ce9c8ab4be7dbbd0fe4526db78d",
441
+ "rag_prompt_template": f"""<|im_start|>system
442
+ {DEFAULT_RAG_PROMPT_CHINESE }<|im_end|>"""
443
+ + """
444
+ <|im_start|>user
445
+ 问题: {input}
446
+ 已知内容: {context}
447
+ 回答: <|im_end|>
448
+ <|im_start|>assistant
449
+ """,
450
+ },
451
+ "chatglm3-6b": {
452
+ "model_id": "THUDM/chatglm3-6b",
453
+ "remote_code": True,
454
+ "start_message": DEFAULT_SYSTEM_PROMPT_CHINESE,
455
+ "tokenizer_kwargs": {"add_special_tokens": False},
456
+ "rag_prompt_template": f"""{DEFAULT_RAG_PROMPT_CHINESE }"""
457
+ + """
458
+ 问题: {input}
459
+ 已知内容: {context}
460
+ 回答:
461
+ """,
462
+ },
463
+ "glm-4-9b-chat": {
464
+ "model_id": "THUDM/glm-4-9b-chat",
465
+ "remote_code": True,
466
+ "start_message": DEFAULT_SYSTEM_PROMPT_CHINESE,
467
+ "tokenizer_kwargs": {"add_special_tokens": False},
468
+ "rag_prompt_template": f"""{DEFAULT_RAG_PROMPT_CHINESE }"""
469
+ + """
470
+ 问题: {input}
471
+ 已知内容: {context}
472
+ 回答:
473
+ """,
474
+ },
475
+ "baichuan2-7b-chat": {
476
+ "model_id": "baichuan-inc/Baichuan2-7B-Chat",
477
+ "remote_code": True,
478
+ "start_message": DEFAULT_SYSTEM_PROMPT_CHINESE,
479
+ "tokenizer_kwargs": {"add_special_tokens": False},
480
+ "stop_tokens": ["<unk>", "</s>"],
481
+ "rag_prompt_template": f"""{DEFAULT_RAG_PROMPT_CHINESE }"""
482
+ + """
483
+ 问题: {input}
484
+ 已知内容: {context}
485
+ 回答:
486
+ """,
487
+ },
488
+ "minicpm-2b-dpo": {
489
+ "model_id": "openbmb/MiniCPM-2B-dpo-fp16",
490
+ "remote_code": True,
491
+ "start_message": DEFAULT_SYSTEM_PROMPT_CHINESE,
492
+ },
493
+ "internlm2-chat-1.8b": {
494
+ "model_id": "internlm/internlm2-chat-1_8b",
495
+ "remote_code": True,
496
+ "start_message": DEFAULT_SYSTEM_PROMPT_CHINESE,
497
+ "stop_tokens": ["</s>", "<|im_end|>"],
498
+ "partial_text_processor": internlm_partial_text_processor,
499
+ },
500
+ "qwen1.5-1.8b-chat": {
501
+ "model_id": "Qwen/Qwen1.5-1.8B-Chat",
502
+ "remote_code": False,
503
+ "start_message": DEFAULT_SYSTEM_PROMPT_CHINESE,
504
+ "stop_tokens": ["<|im_end|>", "<|endoftext|>"],
505
+ "rag_prompt_template": f"""<|im_start|>system
506
+ {DEFAULT_RAG_PROMPT_CHINESE }<|im_end|>"""
507
+ + """
508
+ <|im_start|>user
509
+ 问题: {input}
510
+ 已知内容: {context}
511
+ 回答: <|im_end|>
512
+ <|im_start|>assistant
513
+ """,
514
+ },
515
+ },
516
+ "Japanese": {
517
+ "youri-7b-chat": {
518
+ "model_id": "rinna/youri-7b-chat",
519
+ "remote_code": False,
520
+ "start_message": f"設定: {DEFAULT_SYSTEM_PROMPT_JAPANESE}\n",
521
+ "history_template": "ユーザー: {user}\nシステム: {assistant}\n",
522
+ "current_message_template": "ユーザー: {user}\nシステム: {assistant}",
523
+ "tokenizer_kwargs": {"add_special_tokens": False},
524
+ "partial_text_processor": youri_partial_text_processor,
525
+ },
526
+ },
527
+ }
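+
+ # Example lookup (illustrative), as used in app.py:
+ # cfg = SUPPORTED_LLM_MODELS["English"]["qwen2.5-0.5b-instruct"]
+ # cfg["model_id"] # -> "Qwen/Qwen2.5-0.5B-Instruct"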
528
+
529
+ SUPPORTED_EMBEDDING_MODELS = {
530
+ "English": {
531
+ "bge-small-en-v1.5": {
532
+ "model_id": "BAAI/bge-small-en-v1.5",
533
+ "mean_pooling": False,
534
+ "normalize_embeddings": True,
535
+ },
536
+ "bge-large-en-v1.5": {
537
+ "model_id": "BAAI/bge-large-en-v1.5",
538
+ "mean_pooling": False,
539
+ "normalize_embeddings": True,
540
+ },
541
+ "bge-m3": {
542
+ "model_id": "BAAI/bge-m3",
543
+ "mean_pooling": False,
544
+ "normalize_embeddings": True,
545
+ },
546
+ },
547
+ "Chinese": {
548
+ "bge-small-zh-v1.5": {
549
+ "model_id": "BAAI/bge-small-zh-v1.5",
550
+ "mean_pooling": False,
551
+ "normalize_embeddings": True,
552
+ },
553
+ "bge-large-zh-v1.5": {
554
+ "model_id": "BAAI/bge-large-zh-v1.5",
555
+ "mean_pooling": False,
556
+ "normalize_embeddings": True,
557
+ },
558
+ "bge-m3": {
559
+ "model_id": "BAAI/bge-m3",
560
+ "mean_pooling": False,
561
+ "normalize_embeddings": True,
562
+ },
563
+ },
564
+ }
565
+
566
+
567
+ SUPPORTED_RERANK_MODELS = {
568
+ "bge-reranker-v2-m3": {"model_id": "BAAI/bge-reranker-v2-m3"},
569
+ "bge-reranker-large": {"model_id": "BAAI/bge-reranker-large"},
570
+ "bge-reranker-base": {"model_id": "BAAI/bge-reranker-base"},
571
+ }
572
+
573
+ compression_configs = {
574
+ "zephyr-7b-beta": {
575
+ "sym": True,
576
+ "group_size": 64,
577
+ "ratio": 0.6,
578
+ },
579
+ "mistral-7b": {
580
+ "sym": True,
581
+ "group_size": 64,
582
+ "ratio": 0.6,
583
+ },
584
+ "minicpm-2b-dpo": {
585
+ "sym": True,
586
+ "group_size": 64,
587
+ "ratio": 0.6,
588
+ },
589
+ "gemma-2b-it": {
590
+ "sym": True,
591
+ "group_size": 64,
592
+ "ratio": 0.6,
593
+ },
594
+ "notus-7b-v1": {
595
+ "sym": True,
596
+ "group_size": 64,
597
+ "ratio": 0.6,
598
+ },
599
+ "neural-chat-7b-v3-1": {
600
+ "sym": True,
601
+ "group_size": 64,
602
+ "ratio": 0.6,
603
+ },
604
+ "llama-2-chat-7b": {
605
+ "sym": True,
606
+ "group_size": 128,
607
+ "ratio": 0.8,
608
+ },
609
+ "llama-3-8b-instruct": {
610
+ "sym": True,
611
+ "group_size": 128,
612
+ "ratio": 0.8,
613
+ },
614
+ "gemma-7b-it": {
615
+ "sym": True,
616
+ "group_size": 128,
617
+ "ratio": 0.8,
618
+ },
619
+ "chatglm2-6b": {
620
+ "sym": True,
621
+ "group_size": 128,
622
+ "ratio": 0.72,
623
+ },
624
+ "qwen-7b-chat": {"sym": True, "group_size": 128, "ratio": 0.6},
625
+ "qwen2.5-7b-instruct": {"sym": True, "group_size": 128, "ratio": 1.0},
626
+ "qwen2.5-3b-instruct": {"sym": True, "group_size": 128, "ratio": 1.0},
627
+ "qwen2.5-14b-instruct": {"sym": True, "group_size": 128, "ratio": 1.0},
628
+ "qwen2.5-1.5b-instruct": {"sym": True, "group_size": 128, "ratio": 1.0},
629
+ "qwen2.5-0.5b-instruct": {"sym": True, "group_size": 128, "ratio": 1.0},
630
+ "red-pajama-3b-chat": {
631
+ "sym": False,
632
+ "group_size": 128,
633
+ "ratio": 0.5,
634
+ },
635
+ "llama-3.2-3b-instruct": {"sym": False, "group_size": 64, "ratio": 1.0, "dataset": "wikitext2", "awq": True, "all_layers": True, "scale_estimation": True},
636
+ "llama-3.2-1b-instruct": {"sym": False, "group_size": 64, "ratio": 1.0, "dataset": "wikitext2", "awq": True, "all_layers": True, "scale_estimation": True},
637
+ "default": {
638
+ "sym": False,
639
+ "group_size": 128,
640
+ "ratio": 0.8,
641
+ },
642
+ }
643
+
644
+
645
+ def get_optimum_cli_command(model_id, weight_format, output_dir, compression_options=None, enable_awq=False, trust_remote_code=False):
646
+ base_command = "optimum-cli export openvino --model {} --task text-generation-with-past --weight-format {}"
647
+ command = base_command.format(model_id, weight_format)
648
+ if compression_options:
649
+ compression_args = " --group-size {} --ratio {}".format(compression_options["group_size"], compression_options["ratio"])
650
+ if compression_options["sym"]:
651
+ compression_args += " --sym"
652
+ if enable_awq or compression_options.get("awq", False):
653
+ compression_args += " --awq --dataset wikitext2 --num-samples 128"
654
+ if compression_options.get("scale_estimation", False):
655
+ compression_args += " --scale-estimation"
656
+ if compression_options.get("all_layers", False):
657
+ compression_args += " --all-layers"
658
+
659
+ command = command + compression_args
660
+ if trust_remote_code:
661
+ command += " --trust-remote-code"
662
+
663
+ command += " {}".format(output_dir)
664
+ return command
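+
+ # Example (illustrative):
+ # get_optimum_cli_command("Qwen/Qwen2.5-7B-Instruct", "int4", "qwen2.5-7b-instruct/INT4_compressed_weights",
+ # compression_configs["qwen2.5-7b-instruct"])
+ # -> "optimum-cli export openvino --model Qwen/Qwen2.5-7B-Instruct --task text-generation-with-past
+ # --weight-format int4 --group-size 128 --ratio 1.0 --sym qwen2.5-7b-instruct/INT4_compressed_weights"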
665
+
666
+
667
+ default_language = "English"
668
+
669
+ SUPPORTED_OPTIMIZATIONS = ["INT4", "INT4-AWQ", "INT8", "FP16"]
670
+
671
+
672
+ def get_llm_selection_widget(languages=list(SUPPORTED_LLM_MODELS), models=SUPPORTED_LLM_MODELS[default_language], show_preconverted_checkbox=True):
673
+ import ipywidgets as widgets
674
+
675
+ lang_dropdown = widgets.Dropdown(options=languages or [])
676
+
677
+ # Define dependent drop down
678
+
679
+ model_dropdown = widgets.Dropdown(options=models)
680
+
681
+ def dropdown_handler(change):
682
+ global default_language
683
+ default_language = change.new
684
+ # If statement checking on dropdown value and changing options of the dependent dropdown accordingly
685
+ model_dropdown.options = SUPPORTED_LLM_MODELS[change.new]
686
+
687
+ lang_dropdown.observe(dropdown_handler, names="value")
688
+ compression_dropdown = widgets.Dropdown(options=SUPPORTED_OPTIMIZATIONS)
689
+ preconverted_checkbox = widgets.Checkbox(value=True)
690
+
691
+ form_items = []
692
+
693
+ if languages:
694
+ form_items.append(widgets.Box([widgets.Label(value="Language:"), lang_dropdown]))
695
+ form_items.extend(
696
+ [
697
+ widgets.Box([widgets.Label(value="Model:"), model_dropdown]),
698
+ widgets.Box([widgets.Label(value="Compression:"), compression_dropdown]),
699
+ ]
700
+ )
701
+ if show_preconverted_checkbox:
702
+ form_items.append(widgets.Box([widgets.Label(value="Use preconverted models:"), preconverted_checkbox]))
703
+
704
+ form = widgets.Box(
705
+ form_items,
706
+ layout=widgets.Layout(
707
+ display="flex",
708
+ flex_flow="column",
709
+ border="solid 1px",
710
+ # align_items='stretch',
711
+ width="30%",
712
+ padding="1%",
713
+ ),
714
+ )
715
+ return form, lang_dropdown, model_dropdown, compression_dropdown, preconverted_checkbox
716
+
717
+
718
+ def convert_tokenizer(model_id, remote_code, model_dir):
719
+ import openvino as ov
720
+ from transformers import AutoTokenizer
721
+ from openvino_tokenizers import convert_tokenizer
722
+
723
+ hf_tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=remote_code)
724
+ ov_tokenizer, ov_detokenizer = convert_tokenizer(hf_tokenizer, with_detokenizer=True)
725
+ ov.save_model(ov_tokenizer, model_dir / "openvino_tokenizer.xml")
726
+ ov.save_model(ov_detokenizer, model_dir / "openvino_detokenizer.xml")
727
+
728
+
729
+ def convert_and_compress_model(model_id, model_config, precision, use_preconverted=True):
730
+ from pathlib import Path
731
+ from IPython.display import Markdown, display
732
+ import subprocess # nosec - disable B404:import-subprocess check
733
+ import platform
734
+
735
+ pt_model_id = model_config["model_id"]
736
+ pt_model_name = model_id.split("-")[0]
737
+ model_subdir = precision if precision == "FP16" else precision + "_compressed_weights"
738
+ model_dir = Path(pt_model_name) / model_subdir
739
+ remote_code = model_config.get("remote_code", False)
740
+ if (model_dir / "openvino_model.xml").exists():
741
+ print(f"✅ {precision} {model_id} model already converted and can be found in {model_dir}")
742
+
743
+ if not (model_dir / "openvino_tokenizer.xml").exists() or not (model_dir / "openvino_detokenizer.xml").exists():
744
+ convert_tokenizer(pt_model_id, remote_code, model_dir)
745
+ return model_dir
746
+ if use_preconverted:
747
+ OV_ORG = "OpenVINO"
748
+ pt_model_name = pt_model_id.split("/")[-1]
749
+ ov_model_name = pt_model_name + f"-{precision.lower()}-ov"
750
+ ov_model_hub_id = f"{OV_ORG}/{ov_model_name}"
751
+ import huggingface_hub as hf_hub
752
+
753
+ hub_api = hf_hub.HfApi()
754
+ if hub_api.repo_exists(ov_model_hub_id):
755
+ print(f"⌛Found preconverted {precision} {model_id}. Downloading model started. It may takes some time.")
756
+ hf_hub.snapshot_download(ov_model_hub_id, local_dir=model_dir)
757
+ print(f"✅ {precision} {model_id} model downloaded and can be found in {model_dir}")
758
+ return model_dir
759
+
760
+ model_compression_params = {}
761
+ if "INT4" in precision:
762
+ model_compression_params = compression_configs.get(model_id, compression_configs["default"])
763
+ weight_format = precision.split("-")[0].lower()
764
+ optimum_cli_command = get_optimum_cli_command(pt_model_id, weight_format, model_dir, model_compression_params, "AWQ" in precision, remote_code)
765
+ print(f"⌛ {model_id} conversion to {precision} started. It may takes some time.")
766
+ display(Markdown("**Export command:**"))
767
+ display(Markdown(f"`{optimum_cli_command}`"))
768
+ subprocess.run(optimum_cli_command.split(" "), shell=(platform.system() == "Windows"), check=True)
769
+ print(f"✅ {precision} {model_id} model converted and can be found in {model_dir}")
770
+ return model_dir
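+
+ # Example (illustrative):
+ # model_dir = convert_and_compress_model("qwen2.5-0.5b-instruct",
+ # SUPPORTED_LLM_MODELS["English"]["qwen2.5-0.5b-instruct"], "INT4")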
771
+
772
+
773
+ def compare_model_size(model_dir):
774
+ fp16_weights = model_dir.parent / "FP16" / "openvino_model.bin"
775
+ int8_weights = model_dir.parent / "INT8_compressed_weights" / "openvino_model.bin"
776
+ int4_weights = model_dir.parent / "INT4_compressed_weights" / "openvino_model.bin"
777
+ int4_awq_weights = model_dir.parent / "INT4-AWQ_compressed_weights" / "openvino_model.bin"
778
+
779
+ if fp16_weights.exists():
780
+ print(f"Size of FP16 model is {fp16_weights.stat().st_size / 1024 / 1024:.2f} MB")
781
+ for precision, compressed_weights in zip(["INT8", "INT4", "INT4-AWQ"], [int8_weights, int4_weights, int4_awq_weights]):
782
+ if compressed_weights.exists():
783
+ print(f"Size of model with {precision} compressed weights is {compressed_weights.stat().st_size / 1024 / 1024:.2f} MB")
784
+ if compressed_weights.exists() and fp16_weights.exists():
785
+ print(f"Compression rate for {precision} model: {fp16_weights.stat().st_size / compressed_weights.stat().st_size:.3f}")
notebook_utils.py ADDED
@@ -0,0 +1,715 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[ ]:
5
+
6
+
7
+ import os
8
+ import platform
9
+ import sys
10
+ import threading
11
+ import time
12
+ import urllib.parse
13
+ from os import PathLike
14
+ from pathlib import Path
15
+ from typing import List, NamedTuple, Optional, Tuple
16
+
17
+ import numpy as np
18
+ from openvino.runtime import Core, Type, get_version
19
+ from IPython.display import HTML, Image, display
20
+
21
+ import openvino as ov
22
+ from openvino.runtime.passes import Manager, MatcherPass, WrapType, Matcher
23
+ from openvino.runtime import opset10 as ops
24
+
25
+
26
+ # ## Files
27
+ #
28
+ # Load an image, download a file, download an IR model, and create a progress bar to show download progress.
29
+
30
+ # In[ ]:
31
+
32
+
33
+ def device_widget(default="AUTO", exclude=None, added=None):
34
+ import openvino as ov
35
+ import ipywidgets as widgets
36
+
37
+ core = ov.Core()
38
+
39
+ supported_devices = core.available_devices + ["AUTO"]
40
+ exclude = exclude or []
41
+ if exclude:
42
+ for ex_device in exclude:
43
+ if ex_device in supported_devices:
44
+ supported_devices.remove(ex_device)
45
+
46
+ added = added or []
47
+ if added:
48
+ for add_device in added:
49
+ if add_device not in supported_devices:
50
+ supported_devices.append(add_device)
51
+
52
+ device = widgets.Dropdown(
53
+ options=supported_devices,
54
+ value=default,
55
+ description="Device:",
56
+ disabled=False,
57
+ )
58
+ return device
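+
+ # Example (illustrative): device = device_widget(); device.value # -> "AUTO" or any available device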
59
+
60
+
61
+ def quantization_widget(default=True):
62
+ import ipywidgets as widgets
63
+
64
+ to_quantize = widgets.Checkbox(
65
+ value=default,
66
+ description="Quantization",
67
+ disabled=False,
68
+ )
69
+
70
+ return to_quantize
71
+
72
+
73
+ def pip_install(*args):
74
+ import subprocess # nosec - disable B404:import-subprocess check
75
+
76
+ cli_args = []
77
+ for arg in args:
78
+ cli_args.extend(str(arg).split(" "))
79
+ subprocess.run([sys.executable, "-m", "pip", "install", *cli_args], shell=(platform.system() == "Windows"), check=True)
80
+
81
+
82
+ def load_image(path: str) -> np.ndarray:
83
+ """
84
+ Loads an image from `path` and returns it as BGR numpy array. `path`
85
+ should point to an image file, either a local filename or a url. The image is
86
+ not stored to the filesystem. Use the `download_file` function to download and
87
+ store an image.
88
+
89
+ :param path: Local path name or URL to image.
90
+ :return: image as BGR numpy array
91
+ """
92
+ import cv2
93
+ import requests
94
+
95
+ if path.startswith("http"):
96
+ # Set User-Agent to Mozilla because some websites block
97
+ # requests with User-Agent Python
98
+ response = requests.get(path, headers={"User-Agent": "Mozilla/5.0"})
99
+ array = np.asarray(bytearray(response.content), dtype="uint8")
100
+ image = cv2.imdecode(array, -1) # Loads the image as BGR
101
+ else:
102
+ image = cv2.imread(path)
103
+ return image
104
+
105
+
106
+ def download_file(
107
+ url: PathLike,
108
+ filename: PathLike = None,
109
+ directory: PathLike = None,
110
+ show_progress: bool = True,
111
+ silent: bool = False,
112
+ timeout: int = 10,
113
+ ) -> PathLike:
114
+ """
115
+ Download a file from a url and save it to the local filesystem. The file is saved to the
116
+ current directory by default, or to `directory` if specified. If a filename is not given,
117
+ the filename of the URL will be used.
118
+
119
+ :param url: URL that points to the file to download
120
+ :param filename: Name of the local file to save. Should point to the name of the file only,
121
+ not the full path. If None the filename from the url will be used
122
+ :param directory: Directory to save the file to. Will be created if it doesn't exist
123
+ If None the file will be saved to the current working directory
124
+ :param show_progress: If True, show a TQDM progress bar
125
+ :param silent: If True, do not print a message if the file already exists
126
+ :param timeout: Number of seconds before cancelling the connection attempt
127
+ :return: path to downloaded file
128
+ """
129
+ from tqdm.notebook import tqdm_notebook
130
+ import requests
131
+
132
+ filename = filename or Path(urllib.parse.urlparse(url).path).name
133
+ chunk_size = 16384 # make chunks bigger so that not too many updates are triggered for Jupyter front-end
134
+
135
+ filename = Path(filename)
136
+ if len(filename.parts) > 1:
137
+ raise ValueError(
138
+ "`filename` should refer to the name of the file, excluding the directory. "
139
+ "Use the `directory` parameter to specify a target directory for the downloaded file."
140
+ )
141
+
142
+ # create the directory if it does not exist, and add the directory to the filename
143
+ if directory is not None:
144
+ directory = Path(directory)
145
+ directory.mkdir(parents=True, exist_ok=True)
146
+ filename = directory / Path(filename)
147
+
148
+ try:
149
+ response = requests.get(url=url, headers={"User-agent": "Mozilla/5.0"}, stream=True, timeout=timeout)
150
+ response.raise_for_status()
151
+ except (
152
+ requests.exceptions.HTTPError
153
+ ) as error: # For error associated with not-200 codes. Will output something like: "404 Client Error: Not Found for url: {url}"
154
+ raise Exception(error) from None
155
+ except requests.exceptions.Timeout:
156
+ raise Exception(
157
+ "Connection timed out. If you access the internet through a proxy server, please "
158
+ "make sure the proxy is set in the shell from where you launched Jupyter."
159
+ ) from None
160
+ except requests.exceptions.RequestException as error:
161
+ raise Exception(f"File downloading failed with error: {error}") from None
162
+
163
+ # download the file if it does not exist, or if it exists with an incorrect file size
164
+ filesize = int(response.headers.get("Content-length", 0))
165
+ if not filename.exists() or (os.stat(filename).st_size != filesize):
166
+ with tqdm_notebook(
167
+ total=filesize,
168
+ unit="B",
169
+ unit_scale=True,
170
+ unit_divisor=1024,
171
+ desc=str(filename),
172
+ disable=not show_progress,
173
+ ) as progress_bar:
174
+ with open(filename, "wb") as file_object:
175
+ for chunk in response.iter_content(chunk_size):
176
+ file_object.write(chunk)
177
+ progress_bar.update(len(chunk))
178
+ progress_bar.refresh()
179
+ else:
180
+ if not silent:
181
+ print(f"'{filename}' already exists.")
182
+
183
+ response.close()
184
+
185
+ return filename.resolve()
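+
+ # Example (illustrative): path = download_file("https://host.example/model.xml", directory="model")
+ # downloads to ./model/model.xml and returns its resolved path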
186
+
187
+
188
+ def download_ir_model(model_xml_url: str, destination_folder: PathLike = None) -> PathLike:
189
+ """
190
+ Download IR model from `model_xml_url`. Downloads model xml and bin file; the weights file is
191
+ assumed to exist at the same location and name as model_xml_url with a ".bin" extension.
192
+
193
+ :param model_xml_url: URL to model xml file to download
194
+ :param destination_folder: Directory where downloaded model xml and bin are saved. If None, model
195
+ files are saved to the current directory
196
+ :return: path to downloaded xml model file
197
+ """
198
+ model_bin_url = model_xml_url[:-4] + ".bin"
199
+ model_xml_path = download_file(model_xml_url, directory=destination_folder, show_progress=False)
200
+ download_file(model_bin_url, directory=destination_folder)
201
+ return model_xml_path
202
+
203
+
204
+ # ## Images
205
+
206
+ # ### Convert Pixel Data
207
+ #
208
+ # Normalize image pixel values between 0 and 1, and convert images to RGB and BGR.
209
+
210
+ # In[ ]:
211
+
212
+
213
+ def normalize_minmax(data):
214
+ """
215
+ Normalizes the values in `data` between 0 and 1
216
+ """
217
+ if data.max() == data.min():
218
+ raise ValueError("Normalization is not possible because all elements of" f"`data` have the same value: {data.max()}.")
219
+ return (data - data.min()) / (data.max() - data.min())
220
+
221
+
222
+ def to_rgb(image_data: np.ndarray) -> np.ndarray:
223
+ """
224
+ Convert image_data from BGR to RGB
225
+ """
226
+ import cv2
227
+
228
+ return cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)
229
+
230
+
231
+ def to_bgr(image_data: np.ndarray) -> np.ndarray:
232
+ """
233
+ Convert image_data from RGB to BGR
234
+ """
235
+ import cv2
236
+
237
+ return cv2.cvtColor(image_data, cv2.COLOR_RGB2BGR)
238
+
239
+
240
+ # ## Videos
241
+
242
+ # ### Video Player
243
+ #
244
+ # Custom video player to fulfill FPS requirements. You can set target FPS and output size, flip the video horizontally or skip first N frames.
245
+
246
+ # In[ ]:
247
+
248
+
249
+ class VideoPlayer:
250
+ """
251
+ Custom video player to fulfill FPS requirements. You can set target FPS and output size,
252
+ flip the video horizontally or skip first N frames.
253
+
254
+ :param source: Video source. It could be either camera device or video file.
255
+ :param size: Output frame size.
256
+ :param flip: Flip source horizontally.
257
+ :param fps: Target FPS.
258
+ :param skip_first_frames: Skip first N frames.
259
+ """
260
+
261
+ def __init__(self, source, size=None, flip=False, fps=None, skip_first_frames=0, width=1280, height=720):
262
+ import cv2
263
+
264
+ self.cv2 = cv2 # This is done to access the package in class methods
265
+ self.__cap = cv2.VideoCapture(source)
266
+ # try HD by default to get better video quality
267
+ self.__cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
268
+ self.__cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
269
+
270
+ if not self.__cap.isOpened():
271
+ raise RuntimeError(f"Cannot open {'camera' if isinstance(source, int) else ''} {source}")
272
+ # skip first N frames
273
+ self.__cap.set(cv2.CAP_PROP_POS_FRAMES, skip_first_frames)
274
+ # fps of input file
275
+ self.__input_fps = self.__cap.get(cv2.CAP_PROP_FPS)
276
+ if self.__input_fps <= 0:
277
+ self.__input_fps = 60
278
+ # target fps given by user
279
+ self.__output_fps = fps if fps is not None else self.__input_fps
280
+ self.__flip = flip
281
+ self.__size = None
282
+ self.__interpolation = None
283
+ if size is not None:
284
+ self.__size = size
285
+ # AREA better for shrinking, LINEAR better for enlarging
286
+ self.__interpolation = cv2.INTER_AREA if size[0] < self.__cap.get(cv2.CAP_PROP_FRAME_WIDTH) else cv2.INTER_LINEAR
287
+ # first frame
288
+ _, self.__frame = self.__cap.read()
289
+ self.__lock = threading.Lock()
290
+ self.__thread = None
291
+ self.__stop = False
292
+
293
+ """
294
+ Start playing.
295
+ """
296
+
297
+ def start(self):
298
+ self.__stop = False
299
+ self.__thread = threading.Thread(target=self.__run, daemon=True)
300
+ self.__thread.start()
301
+
302
+ """
303
+ Stop playing and release resources.
304
+ """
305
+
306
+ def stop(self):
307
+ self.__stop = True
308
+ if self.__thread is not None:
309
+ self.__thread.join()
310
+ self.__cap.release()
311
+
312
+ def __run(self):
313
+ prev_time = 0
314
+ while not self.__stop:
315
+ t1 = time.time()
316
+ ret, frame = self.__cap.read()
317
+ if not ret:
318
+ break
319
+
320
+ # fulfill target fps
321
+ if 1 / self.__output_fps < time.time() - prev_time:
322
+ prev_time = time.time()
323
+ # replace by current frame
324
+ with self.__lock:
325
+ self.__frame = frame
326
+
327
+ t2 = time.time()
328
+ # time to wait [s] to fulfill input fps
329
+ wait_time = 1 / self.__input_fps - (t2 - t1)
330
+ # wait until
331
+ time.sleep(max(0, wait_time))
332
+
333
+ self.__frame = None
334
+
335
+ """
336
+ Get current frame.
337
+ """
338
+
339
+ def next(self):
340
+ import cv2
341
+
342
+ with self.__lock:
343
+ if self.__frame is None:
344
+ return None
345
+ # need to copy frame, because can be cached and reused if fps is low
346
+ frame = self.__frame.copy()
347
+ if self.__size is not None:
348
+ frame = self.cv2.resize(frame, self.__size, interpolation=self.__interpolation)
349
+ if self.__flip:
350
+ frame = self.cv2.flip(frame, 1)
351
+ return frame
352
+
353
+
354
+ # ## Visualization
+ 
+ # ### Segmentation
+ #
+ # Define a SegmentationMap NamedTuple that keeps the labels and colormap for a segmentation project/dataset. Create CityScapesSegmentation and BinarySegmentation SegmentationMaps. Create a function to convert a segmentation map to an RGB image with a colormap, and to show the segmentation result as an overlay over the original image.
+ 
+ # In[ ]:
+ 
+ 
+ class Label(NamedTuple):
+     index: int
+     color: Tuple
+     name: Optional[str] = None
+ 
+ 
+ # In[ ]:
+ 
+ 
+ class SegmentationMap(NamedTuple):
+     labels: List
+ 
+     def get_colormap(self):
+         return np.array([label.color for label in self.labels])
+ 
+     def get_labels(self):
+         labelnames = [label.name for label in self.labels]
+         if any(labelnames):
+             return labelnames
+         else:
+             return None
+ 
+ 
+ # In[ ]:
+ 
+ 
+ cityscape_labels = [
+     Label(index=0, color=(128, 64, 128), name="road"),
+     Label(index=1, color=(244, 35, 232), name="sidewalk"),
+     Label(index=2, color=(70, 70, 70), name="building"),
+     Label(index=3, color=(102, 102, 156), name="wall"),
+     Label(index=4, color=(190, 153, 153), name="fence"),
+     Label(index=5, color=(153, 153, 153), name="pole"),
+     Label(index=6, color=(250, 170, 30), name="traffic light"),
+     Label(index=7, color=(220, 220, 0), name="traffic sign"),
+     Label(index=8, color=(107, 142, 35), name="vegetation"),
+     Label(index=9, color=(152, 251, 152), name="terrain"),
+     Label(index=10, color=(70, 130, 180), name="sky"),
+     Label(index=11, color=(220, 20, 60), name="person"),
+     Label(index=12, color=(255, 0, 0), name="rider"),
+     Label(index=13, color=(0, 0, 142), name="car"),
+     Label(index=14, color=(0, 0, 70), name="truck"),
+     Label(index=15, color=(0, 60, 100), name="bus"),
+     Label(index=16, color=(0, 80, 100), name="train"),
+     Label(index=17, color=(0, 0, 230), name="motorcycle"),
+     Label(index=18, color=(119, 11, 32), name="bicycle"),
+     Label(index=19, color=(255, 255, 255), name="background"),
+ ]
+ 
+ CityScapesSegmentation = SegmentationMap(cityscape_labels)
+ 
+ binary_labels = [
+     Label(index=0, color=(255, 255, 255), name="background"),
+     Label(index=1, color=(0, 0, 0), name="foreground"),
+ ]
+ 
+ BinarySegmentation = SegmentationMap(binary_labels)
+ 
+ 
+ # In[ ]:
+ 
+ 
+ def segmentation_map_to_image(result: np.ndarray, colormap: np.ndarray, remove_holes: bool = False) -> np.ndarray:
+     """
+     Convert network result of floating point numbers to an RGB image with
+     integer values from 0-255 by applying a colormap.
+ 
+     :param result: A single network result after converting to pixel values in H,W or 1,H,W shape.
+     :param colormap: A numpy array of shape (num_classes, 3) with an RGB value per class.
+     :param remove_holes: If True, remove holes in the segmentation result.
+     :return: An RGB image where each pixel is a uint8 value according to colormap.
+     """
+     import cv2
+ 
+     if len(result.shape) != 2 and result.shape[0] != 1:
+         raise ValueError(f"Expected result with shape (H,W) or (1,H,W), got result with shape {result.shape}")
+ 
+     if len(np.unique(result)) > colormap.shape[0]:
+         raise ValueError(
+             f"Expected max {colormap.shape[0]} classes in result, got {len(np.unique(result))} "
+             "different output values. Please make sure to convert the network output to "
+             "pixel values before calling this function."
+         )
+     elif result.shape[0] == 1:
+         result = result.squeeze(0)
+ 
+     result = result.astype(np.uint8)
+ 
+     contour_mode = cv2.RETR_EXTERNAL if remove_holes else cv2.RETR_TREE
+     mask = np.zeros((result.shape[0], result.shape[1], 3), dtype=np.uint8)
+     for label_index, color in enumerate(colormap):
+         label_index_map = result == label_index
+         label_index_map = label_index_map.astype(np.uint8) * 255
+         contours, hierarchies = cv2.findContours(label_index_map, contour_mode, cv2.CHAIN_APPROX_SIMPLE)
+         cv2.drawContours(
+             mask,
+             contours,
+             contourIdx=-1,
+             color=color.tolist(),
+             thickness=cv2.FILLED,
+         )
+ 
+     return mask
+ 
+ 
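A short hedged sketch of how `segmentation_map_to_image` might be called; the random class-index map is purely illustrative:

```python
import numpy as np

# Fake (H, W) class-index map standing in for a real network result.
seg_result = np.random.randint(0, len(cityscape_labels), size=(256, 512))
rgb_mask = segmentation_map_to_image(seg_result, CityScapesSegmentation.get_colormap())
print(rgb_mask.shape, rgb_mask.dtype)  # (256, 512, 3) uint8
```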
+ def segmentation_map_to_overlay(image, result, alpha, colormap, remove_holes=False) -> np.ndarray:
+     """
+     Returns a new image where a segmentation mask (created with colormap) is overlaid on
+     the source image.
+ 
+     :param image: Source image.
+     :param result: A single network result after converting to pixel values in H,W or 1,H,W shape.
+     :param alpha: Alpha transparency value for the overlay image.
+     :param colormap: A numpy array of shape (num_classes, 3) with an RGB value per class.
+     :param remove_holes: If True, remove holes in the segmentation result.
+     :return: An RGB image with the segmentation mask overlaid on the source image.
+     """
+     import cv2
+ 
+     if len(image.shape) == 2:
+         image = np.repeat(np.expand_dims(image, -1), 3, 2)
+     mask = segmentation_map_to_image(result, colormap, remove_holes)
+     image_height, image_width = image.shape[:2]
+     mask = cv2.resize(src=mask, dsize=(image_width, image_height))
+     return cv2.addWeighted(mask, alpha, image, 1 - alpha, 0)
+ 
+ 
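Continuing the sketch above, the mask can be blended over a source frame; `image` is an assumed RGB uint8 array (for example a frame from `VideoPlayer`):

```python
# Hedged example: blend the colored mask over the source frame.
overlay = segmentation_map_to_overlay(image, seg_result, alpha=0.4, colormap=CityScapesSegmentation.get_colormap())
```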
+ # ### Network Results
+ #
+ # Show network result image, optionally together with the source image and a legend with labels.
+ 
+ # In[ ]:
+ 
+ 
+ def viz_result_image(
+     result_image: np.ndarray,
+     source_image: np.ndarray = None,
+     source_title: str = None,
+     result_title: str = None,
+     labels: SegmentationMap = None,
+     resize: bool = False,
+     bgr_to_rgb: bool = False,
+     hide_axes: bool = False,
+ ):
+     """
+     Show result image, optionally together with source images, and a legend with labels.
+ 
+     :param result_image: Numpy array of RGB result image.
+     :param source_image: Numpy array of source image. If provided this image will be shown
+                          next to the result image. source_image is expected to be in RGB format.
+                          Set bgr_to_rgb to True if source_image is in BGR format.
+     :param source_title: Title to display for the source image.
+     :param result_title: Title to display for the result image.
+     :param labels: SegmentationMap with labels. If provided, a legend will be shown with the given labels.
+     :param resize: If true, resize the result image to the same shape as the source image.
+     :param bgr_to_rgb: If true, convert the source image from BGR to RGB. Use this option if
+                        source_image is a BGR image.
+     :param hide_axes: If true, do not show matplotlib axes.
+     :return: Matplotlib figure with result image.
+     """
+     import cv2
+     import matplotlib.pyplot as plt
+     from matplotlib.lines import Line2D
+ 
+     if bgr_to_rgb:
+         source_image = to_rgb(source_image)
+     if resize:
+         result_image = cv2.resize(result_image, (source_image.shape[1], source_image.shape[0]))
+ 
+     num_images = 1 if source_image is None else 2
+ 
+     fig, ax = plt.subplots(1, num_images, figsize=(16, 8), squeeze=False)
+     if source_image is not None:
+         ax[0, 0].imshow(source_image)
+         ax[0, 0].set_title(source_title)
+ 
+     ax[0, num_images - 1].imshow(result_image)
+     ax[0, num_images - 1].set_title(result_title)
+ 
+     if hide_axes:
+         for a in ax.ravel():
+             a.axis("off")
+     if labels:
+         colors = labels.get_colormap()
+         lines = [
+             Line2D(
+                 [0],
+                 [0],
+                 color=[item / 255 for item in c.tolist()],
+                 linewidth=3,
+                 linestyle="-",
+             )
+             for c in colors
+         ]
+         plt.legend(
+             lines,
+             labels.get_labels(),
+             bbox_to_anchor=(1, 1),
+             loc="upper left",
+             prop={"size": 12},
+         )
+     plt.close(fig)
+     return fig
+ 
+ 
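A possible way to display the overlay from the previous sketch next to its source frame (illustrative; `image` and `overlay` are the assumed arrays from above):

```python
# Hedged example: source frame and overlay side by side, with a Cityscapes legend.
fig = viz_result_image(
    overlay,
    source_image=image,
    source_title="Source",
    result_title="Segmentation overlay",
    labels=CityScapesSegmentation,
    hide_axes=True,
)
fig  # in a notebook cell, returning the figure renders it
```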
+ # ### Live Inference
+ 
+ # In[ ]:
+ 
+ 
+ def show_array(frame: np.ndarray, display_handle=None):
+     """
+     Display array `frame`. Replace information at `display_handle` with `frame`
+     encoded as jpeg image. `frame` is expected to have data in BGR order.
+ 
+     Create a display_handle with: `display_handle = display(display_id=True)`
+     """
+     import cv2
+ 
+     _, frame = cv2.imencode(ext=".jpeg", img=frame)
+     if display_handle is None:
+         display_handle = display(Image(data=frame.tobytes()), display_id=True)
+     else:
+         display_handle.update(Image(data=frame.tobytes()))
+     return display_handle
+ 
+ 
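A hedged sketch of the intended notebook loop, pairing `show_array` with `VideoPlayer` (the webcam index `0` is an assumption):

```python
player = VideoPlayer(source=0, size=(640, 360), fps=30)
player.start()
display_handle = None
try:
    while True:
        frame = player.next()  # BGR frame, as OpenCV returns it
        if frame is None:
            break
        display_handle = show_array(frame, display_handle)  # update the same output cell
finally:
    player.stop()
```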
+ # ## Checks and Alerts
+ #
+ # Create an alert class to show stylized info/error/warning messages and a `check_device` function that checks whether a given device is available.
+ 
+ # In[ ]:
+ 
+ 
+ class NotebookAlert(Exception):
+     def __init__(self, message: str, alert_class: str):
+         """
+         Show an alert box with the given message.
+ 
+         :param message: The message to display.
+         :param alert_class: The class for styling the message. Options: info, warning, success, danger.
+         """
+         self.message = message
+         self.alert_class = alert_class
+         self.show_message()
+ 
+     def show_message(self):
+         display(HTML(f"""<div class="alert alert-{self.alert_class}">{self.message}</div>"""))
+ 
+ 
+ class DeviceNotFoundAlert(NotebookAlert):
+     def __init__(self, device: str):
+         """
+         Show a warning message about an unavailable device. This class does not check whether or
+         not the device is available, use the `check_device` function to check this. `check_device`
+         also shows the warning if the device is not found.
+ 
+         :param device: The unavailable device.
+         :return: A formatted alert box with the message that `device` is not available, and a list
+                  of devices that are available.
+         """
+         ie = Core()
+         supported_devices = ie.available_devices
+         self.message = f"Running this cell requires a {device} device, " "which is not available on this system. "
+         self.alert_class = "warning"
+         if len(supported_devices) == 1:
+             self.message += f"The following device is available: {ie.available_devices[0]}"
+         else:
+             self.message += "The following devices are available: " f"{', '.join(ie.available_devices)}"
+         super().__init__(self.message, self.alert_class)
+ 
+ 
+ def check_device(device: str) -> bool:
+     """
+     Check if the specified device is available on the system.
+ 
+     :param device: Device to check. e.g. CPU, GPU
+     :return: True if the device is available, False if not. If the device is not available,
+              a DeviceNotFoundAlert will be shown.
+     """
+     ie = Core()
+     if device not in ie.available_devices:
+         DeviceNotFoundAlert(device)
+         return False
+     else:
+         return True
+ 
+ 
+ def check_openvino_version(version: str) -> bool:
+     """
+     Check if the specified OpenVINO version is installed.
+ 
+     :param version: the OpenVINO version to check. Example: 2021.4
+     :return: True if the version is installed, False if not. If the version is not installed,
+              an alert message will be shown.
+     """
+     installed_version = get_version()
+     if version not in installed_version:
+         NotebookAlert(
+             f"This notebook requires OpenVINO {version}. "
+             f"The version on your system is: <i>{installed_version}</i>.<br>"
+             "Please run <span style='font-family:monospace'>pip install --upgrade -r requirements.txt</span> "
+             "in the openvino_env environment to install this version. "
+             "See the <a href='https://github.com/openvinotoolkit/openvino_notebooks'>"
+             "OpenVINO Notebooks README</a> for detailed instructions",
+             alert_class="danger",
+         )
+         return False
+     else:
+         return True
+ 
+ 
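A short sketch of how the checks might be used at the top of a notebook cell; the version string is only an example:

```python
# Hedged usage of the device / version checks.
device = "GPU" if check_device("GPU") else "CPU"  # an alert is shown if no GPU is found
check_openvino_version("2024.2")  # shows a "danger" alert when a different version is installed
```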
+ packed_layername_tensor_dict_list = [{"name": "aten::mul/Multiply"}]
+ 
+ 
+ class ReplaceTensor(MatcherPass):
+     def __init__(self, packed_layername_tensor_dict_list):
+         MatcherPass.__init__(self)
+         self.model_changed = False
+ 
+         param = WrapType("opset10.Multiply")
+ 
+         def callback(matcher: Matcher) -> bool:
+             root = matcher.get_match_root()
+             if root is None:
+                 return False
+             for y in packed_layername_tensor_dict_list:
+                 root_name = root.get_friendly_name()
+                 if root_name.find(y["name"]) != -1:
+                     max_fp16 = np.array([[[[-np.finfo(np.float16).max]]]]).astype(np.float32)
+                     new_tensor = ops.constant(max_fp16, Type.f32, name="Constant_4431")
+                     root.set_arguments([root.input_value(0).node, new_tensor])
+                     packed_layername_tensor_dict_list.remove(y)
+ 
+             return True
+ 
+         self.register_matcher(Matcher(param, "ReplaceTensor"), callback)
+ 
+ 
+ def optimize_bge_embedding(model_path, output_model_path):
+     """
+     optimize_bge_embedding is used to optimize a BGE embedding model for the NPU device.
+ 
+     Arguments:
+     model_path {str} -- original BGE IR model path
+     output_model_path {str} -- converted BGE IR model path
+     """
+     core = Core()
+     ov_model = core.read_model(model_path)
+     manager = Manager()
+     manager.register_pass(ReplaceTensor(packed_layername_tensor_dict_list))
+     manager.run_passes(ov_model)
+     ov.save_model(ov_model, output_model_path, compress_to_fp16=False)
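A hedged example of the intended call, with placeholder IR paths for a converted BGE embedding model:

```python
# Both paths are hypothetical; the function reads the IR, patches the matched
# Multiply nodes, and saves the result without FP16 compression.
optimize_bge_embedding(
    model_path="bge-small-en-v1.5/openvino_model.xml",
    output_model_path="bge-small-en-v1.5-npu/openvino_model.xml",
)
```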
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ openvino>=2024.2.0
+ openvino-tokenizers[transformers]
+ torch>=2.1
+ datasets
+ accelerate
+ gradio>=4.19
+ onnx<=1.16.1; sys_platform=='win32'
+ einops
+ transformers>=4.43.1
+ transformers_stream_generator
+ tiktoken
+ bitsandbytes
+ optimum-intel @ git+https://github.com/huggingface/optimum-intel.git
+ nncf @ git+https://github.com/openvinotoolkit/nncf.git