Update tokenization_phi3_small.py

#29
by XirenZhou - opened
Files changed (1) hide show
  1. tokenization_phi3_small.py +26 -1
tokenization_phi3_small.py CHANGED
@@ -5,6 +5,7 @@ from typing import Collection, List, Optional, Dict, Set, Tuple, Union
5
  from functools import cached_property
6
 
7
  import base64
 
8
 
9
  from transformers import PreTrainedTokenizer, AddedToken, AutoConfig
10
  from transformers.models.auto.tokenization_auto import get_tokenizer_config
@@ -102,7 +103,31 @@ class Phi3SmallTokenizer(PreTrainedTokenizer):
102
  super().__init__(**kwargs)
103
  self.errors = errors
104
 
105
- base = tiktoken.get_encoding("cl100k_base")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  if vocab_file is None:
107
  self.mergeable_ranks: Dict[bytes, int] = base._mergeable_ranks
108
  else:
 
5
  from functools import cached_property
6
 
7
  import base64
8
+ import requests
9
 
10
  from transformers import PreTrainedTokenizer, AddedToken, AutoConfig
11
  from transformers.models.auto.tokenization_auto import get_tokenizer_config
 
103
  super().__init__(**kwargs)
104
  self.errors = errors
105
 
106
+ try:
107
+ base = tiktoken.get_encoding("cl100k_base")
108
+ # This handles the scenario where the user has restricted internet access
109
+ # and thus fails to download the tokenizer file from https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken
110
+ # It is assumed that the user can still access files on the Hugging Face Hub.
111
+ except requests.RequestException:
112
+ import hashlib
113
+ from transformers.utils import cached_file
114
+ cached_tokenizer_path = cached_file(
115
+ "microsoft/Phi-3-small-8k-instruct",
116
+ "cl100k_base.tiktoken",
117
+ _raise_exceptions_for_gated_repo=False,
118
+ _raise_exceptions_for_missing_entries=False,
119
+ _raise_exceptions_for_connection_errors=False
120
+ )
121
+ tiktoken_cache_dir = os.path.dirname(cached_tokenizer_path)
122
+ tiktoken_cache_path = os.path.join(
123
+ tiktoken_cache_dir,
124
+ hashlib.sha1("https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken".encode()).hexdigest()
125
+ )
126
+ if not os.path.exists(tiktoken_cache_path):
127
+ os.rename(cached_tokenizer_path, tiktoken_cache_path)
128
+ os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
129
+ base = tiktoken.get_encoding("cl100k_base")
130
+
131
  if vocab_file is None:
132
  self.mergeable_ranks: Dict[bytes, int] = base._mergeable_ranks
133
  else: