Delete tokenizer_template_switch.py
tokenizer_template_switch.py (deleted, +0 -95)
@@ -1,95 +0,0 @@
```python
import re
from transformers import AutoTokenizer

def extract_separators(template):
    """
    Extracts separators used in the tokenization template.
    """
    # Match the specific pattern between '{{' and "+ message['content']"
    pattern = r"\{\{\s*([^{}]+?)\s*\+ message\['content'\]"
    matches = re.findall(pattern, template)
    # Clean up any extra spaces and return the matches
    separators = [match.strip() for match in matches]

    if any("message['role']" in element for element in separators):
        roles = ["system", "user", "assistant"]
        separators_ = []
        for role in roles:
            # Substitute each concrete role into the templated separator and drop the quotes
            separators_.append(separators[0].replace(" + message['role'] + ", role).replace("'", ""))
        return separators_

    return separators

def detect_eos_token(jinja_template, tokenizer):
    """
    Guesses the end-of-sequence token used by the chat template.
    """
    if "<|im_end|>" in jinja_template:
        return "<|im_end|>"
    if "</s>" in jinja_template:
        return "</s>"
    if "eos_token" in jinja_template:
        return tokenizer.eos_token
    return "<|endoftext|>"

def recover_messages(formatted_message, separators, eos_token):
    """
    Recovers the original messages from the formatted message string.
    """
    # Split the formatted message using the end-of-sequence token
    split_messages = formatted_message.split(eos_token)

    # Remove the last empty string if it exists due to a trailing separator
    if split_messages and split_messages[-1].strip() == '':
        split_messages.pop()

    # Prepare the list to hold the recovered messages
    recovered_messages = []

    # Roles after the first message alternate between "user" and "assistant"
    alternate_roles = ["user", "assistant"]

    # Iterate over the split messages
    for index, message_content in enumerate(split_messages):
        # Determine the role, starting with "system" for the first message,
        # then alternating between "user" and "assistant" for subsequent messages
        if index == 0:
            role = "system"
        else:
            role = alternate_roles[(index - 1) % 2]

        # Clean the message content by removing leading/trailing whitespace and separators
        clean_content = message_content.strip()
        for separator in separators:
            clean_content = clean_content.replace(separator.strip("'"), '', 1).strip()

        # Append the cleaned message with its role to the list
        recovered_messages.append({"role": role, "content": clean_content})

    return recovered_messages

def recover_chat_messages(tokenized_chat, tokenizer):
    """
    Given a tokenized_chat string and a tokenizer, returns the list of message dictionaries.
    """
    jinja_template = tokenizer.chat_template
    separators = extract_separators(jinja_template)
    eos_token = detect_eos_token(jinja_template, tokenizer)
    recovered_messages = recover_messages(tokenized_chat, separators, eos_token)
    return recovered_messages

# Example usage
if __name__ == "__main__":
    checkpoint = "Qwen/Qwen1.5-0.5B"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    messages = [
        {
            "role": "system",
            "content": "You are a friendly chatbot who always responds in the style of a pirate",
        },
        {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
    ]
    tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=False)
    print(tokenized_chat)

    recovered_messages = recover_chat_messages(tokenized_chat, tokenizer)
    print(recovered_messages)
```
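For reference, the deleted `recover_messages` worked by splitting the rendered chat string on the EOS token and stripping the role separators from each chunk. A minimal self-contained sketch of that round-trip, assuming a ChatML-style template (the hand-written `formatted` string below is hypothetical, not actual tokenizer output):

```python
# Hypothetical ChatML-formatted string, mimicking what apply_chat_template would render
formatted = (
    "<|im_start|>system\nYou are a friendly chatbot<|im_end|>\n"
    "<|im_start|>user\nHow many helicopters can a human eat?<|im_end|>\n"
)

eos_token = "<|im_end|>"
separators = ["<|im_start|>system", "<|im_start|>user", "<|im_start|>assistant"]

messages = []
roles = ["user", "assistant"]
# Split on the EOS token, drop empty trailing chunks, and strip separators per chunk
for i, chunk in enumerate(s for s in formatted.split(eos_token) if s.strip()):
    content = chunk.strip()
    for sep in separators:
        content = content.replace(sep, "", 1).strip()
    messages.append({"role": "system" if i == 0 else roles[(i - 1) % 2], "content": content})

print(messages)
# [{'role': 'system', 'content': 'You are a friendly chatbot'},
#  {'role': 'user', 'content': 'How many helicopters can a human eat?'}]
```

Note that this recovery is purely positional: it assumes the first message is a system prompt and that the remaining turns strictly alternate user/assistant, which is why the deleted helper could misassign roles for conversations without a system message.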