ksuriuri committed
Commit
f3c1d0f
1 Parent(s): 0891486

Update tokenization_chatglm.py


Running the following code:
```
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("/home/oneway/ssd2t/model/ZhipuAI/glm-4-9b-chat", trust_remote_code=True)
new_str = tokenizer.decode(198)
print(new_str)
```

fails with: `TypeError: token should only be of type types or str` (the upstream message itself has a typo, `types` instead of `bytes`)

The cause is that the keys in GLM-4's vocabulary are stored as `bytes`, and when a `bytes` object is iterated inside transformers' `_decode` function, each element comes out as an `int`.
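A minimal illustration of that Python behavior (the token value here is hypothetical):

```
# A vocab entry stored as bytes, as in the GLM-4 tokenizer
token = b"\n"

# Indexing or iterating a bytes object yields ints, not 1-byte strings
print(type(token[0]), token[0])  # <class 'int'> 10
```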

The problem can be fixed by modifying the `convert_tokens_to_string` function in `tokenization_chatglm.py` as follows:
```
def convert_tokens_to_string(self, tokens: List[Union[bytes, str, int]]) -> str:
    """
    Converts a sequence of tokens in a single string.
    """
    text = ""
    temp = b""
    for t in tokens:
        if isinstance(t, int):
            # ints produced by iterating a bytes vocab key; map back to a 1-char str
            t = chr(t)
        if isinstance(t, str):
            if temp:
                text += temp.decode("utf-8", errors="replace")
                temp = b""
            text += t
        elif isinstance(t, bytes):
            temp += t
        else:
            raise TypeError("token should only be of type int, bytes or str")
    if temp:
        text += temp.decode("utf-8", errors="replace")
    return text
```
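As a standalone sketch of the patched logic (module-level rather than a method, for illustration only), the function now handles all three token types:

```
from typing import List, Union

def convert_tokens_to_string(tokens: List[Union[bytes, str, int]]) -> str:
    # Mirror of the patched method, without the tokenizer class around it
    text = ""
    temp = b""
    for t in tokens:
        if isinstance(t, int):
            t = chr(t)  # ints from an iterated bytes key become 1-char strings
        if isinstance(t, str):
            if temp:
                text += temp.decode("utf-8", errors="replace")
                temp = b""
            text += t
        elif isinstance(t, bytes):
            temp += t
        else:
            raise TypeError("token should only be of type int, bytes or str")
    if temp:
        text += temp.decode("utf-8", errors="replace")
    return text

print(repr(convert_tokens_to_string([10])))          # '\n' instead of a TypeError
print(repr(convert_tokens_to_string([b"hi", "!"])))  # 'hi!'
```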

Files changed (1)
  1. tokenization_chatglm.py +4 -4
tokenization_chatglm.py CHANGED
```
@@ -63,22 +63,22 @@ class ChatGLM4Tokenizer(PreTrainedTokenizer):
         vocab.update(self.added_tokens_encoder)
         return vocab
 
-    def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
+    def convert_tokens_to_string(self, tokens: List[Union[bytes, str, int]]) -> str:
         """
         Converts a sequence of tokens in a single string.
         """
         text = ""
         temp = b""
         for t in tokens:
+            if isinstance(t, int):
+                t = chr(t)
             if isinstance(t, str):
                 if temp:
                     text += temp.decode("utf-8", errors="replace")
                     temp = b""
                 text += t
             elif isinstance(t, bytes):
                 temp += t
             else:
-                raise TypeError("token should only be of type types or str")
+                raise TypeError("token should only be of type int, bytes or str")
         if temp:
             text += temp.decode("utf-8", errors="replace")
         return text
```