Update tokenization_minicpm.py
Browse files
Files changed: tokenization_minicpm.py (+10 -10)
tokenization_minicpm.py
CHANGED
@@ -4,7 +4,6 @@ import keyword
|
|
4 |
import traceback
|
5 |
import uuid
|
6 |
from collections import deque
|
7 |
-
from copy import deepcopy
|
8 |
from logging import getLogger
|
9 |
from typing import Any, Dict, List, Optional, Union
|
10 |
|
@@ -17,6 +16,7 @@ from jsonschema import Draft202012Validator, exceptions, validate
|
|
17 |
from transformers import LlamaTokenizerFast
|
18 |
from transformers.tokenization_utils_base import BatchEncoding
|
19 |
from transformers.utils import TensorType
|
|
|
20 |
|
21 |
|
22 |
logger = getLogger(__name__)
|
@@ -148,7 +148,7 @@ class MiniCPMTokenizer(LlamaTokenizerFast):
|
|
148 |
tool_calls.append(this_one)
|
149 |
|
150 |
return {
|
151 |
-
"content": content
|
152 |
"tool_calls": [
|
153 |
{"type": "function", "function": tool_call, "id": "call_" + uuid.uuid4().hex}
|
154 |
for tool_call in tool_calls
|
@@ -158,13 +158,13 @@ class MiniCPMTokenizer(LlamaTokenizerFast):
|
|
158 |
except:
|
159 |
logger.error(traceback.format_exc())
|
160 |
return {
|
161 |
-
"content": content
|
162 |
"role": "assistant",
|
163 |
"thought": thought_string,
|
164 |
}
|
165 |
else:
|
166 |
return {
|
167 |
-
"content": sequence
|
168 |
"role": "assistant",
|
169 |
"thought": thought_string,
|
170 |
}
|
@@ -259,10 +259,11 @@ def message_format(msg, system_suffix="", user_prefix=""):
|
|
259 |
content = thought_prefix + content
|
260 |
msg["content"] = content
|
261 |
elif msg["role"] == "user":
|
262 |
-
|
|
|
263 |
elif msg["role"] == "system":
|
264 |
msg["content"] = msg["content"] + "\n" + system_suffix
|
265 |
-
msg["content"] = msg["content"]
|
266 |
return msg
|
267 |
|
268 |
|
@@ -361,12 +362,12 @@ func2(params)
|
|
361 |
<|tool_call_end|>
|
362 |
{{answer the user's question directly or ask the user for more information}}
|
363 |
"""
|
364 |
-
tools_string = tools_template.format(tools=tools_string)
|
365 |
else:
|
366 |
tools_string = ""
|
367 |
|
368 |
if add_to_system:
|
369 |
-
if len(messages) > 0 and messages[0]["role"] != "system" and tools_string
|
370 |
messages.insert(0, {"role": "system", "content": ""})
|
371 |
return [message_format(msg, system_suffix=tools_string, user_prefix="") for msg in messages]
|
372 |
else:
|
@@ -429,5 +430,4 @@ def resolve_ast_by_type(value):
|
|
429 |
output = ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]"
|
430 |
else:
|
431 |
raise Exception(f"Unsupported AST type: {type(value)}")
|
432 |
-
return output
|
433 |
-
|
|
|
4 |
import traceback
|
5 |
import uuid
|
6 |
from collections import deque
|
|
|
7 |
from logging import getLogger
|
8 |
from typing import Any, Dict, List, Optional, Union
|
9 |
|
|
|
16 |
from transformers import LlamaTokenizerFast
|
17 |
from transformers.tokenization_utils_base import BatchEncoding
|
18 |
from transformers.utils import TensorType
|
19 |
+
from copy import deepcopy
|
20 |
|
21 |
|
22 |
logger = getLogger(__name__)
|
|
|
148 |
tool_calls.append(this_one)
|
149 |
|
150 |
return {
|
151 |
+
"content": content,
|
152 |
"tool_calls": [
|
153 |
{"type": "function", "function": tool_call, "id": "call_" + uuid.uuid4().hex}
|
154 |
for tool_call in tool_calls
|
|
|
158 |
except:
|
159 |
logger.error(traceback.format_exc())
|
160 |
return {
|
161 |
+
"content": content,
|
162 |
"role": "assistant",
|
163 |
"thought": thought_string,
|
164 |
}
|
165 |
else:
|
166 |
return {
|
167 |
+
"content": sequence,
|
168 |
"role": "assistant",
|
169 |
"thought": thought_string,
|
170 |
}
|
|
|
259 |
content = thought_prefix + content
|
260 |
msg["content"] = content
|
261 |
elif msg["role"] == "user":
|
262 |
+
if user_prefix != "":
|
263 |
+
msg["content"] = user_prefix + "\n" + msg["content"]
|
264 |
elif msg["role"] == "system":
|
265 |
msg["content"] = msg["content"] + "\n" + system_suffix
|
266 |
+
msg["content"] = msg["content"]
|
267 |
return msg
|
268 |
|
269 |
|
|
|
362 |
<|tool_call_end|>
|
363 |
{{answer the user's question directly or ask the user for more information}}
|
364 |
"""
|
365 |
+
tools_string = tools_template.format(tools=tools_string)
|
366 |
else:
|
367 |
tools_string = ""
|
368 |
|
369 |
if add_to_system:
|
370 |
+
if len(messages) > 0 and messages[0]["role"] != "system" and len(tools_string.strip()) > 0:
|
371 |
messages.insert(0, {"role": "system", "content": ""})
|
372 |
return [message_format(msg, system_suffix=tools_string, user_prefix="") for msg in messages]
|
373 |
else:
|
|
|
430 |
output = ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]"
|
431 |
else:
|
432 |
raise Exception(f"Unsupported AST type: {type(value)}")
|
433 |
+
return output
|
|