Hansimov commited on
Commit
64105c5
·
1 Parent(s): bcf3537

:gem: [Feature] New TokenChecker: count tokens and check token limit

Browse files
Files changed (1) hide show
  1. networks/huggingchat_streamer.py +40 -10
networks/huggingchat_streamer.py CHANGED
@@ -24,6 +24,39 @@ from messagers.message_outputer import OpenaiStreamOutputer
24
  from messagers.message_composer import MessageComposer
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  class HuggingchatRequester:
28
  def __init__(self, model: str):
29
  if model in MODEL_MAP.keys():
@@ -175,6 +208,9 @@ class HuggingchatRequester:
175
  messages
176
  )
177
 
 
 
 
178
  self.get_hf_chat_id()
179
  self.get_conversation_id(system_prompt=system_prompt)
180
  message_id = self.get_last_message_id()
@@ -216,13 +252,6 @@ class HuggingchatStreamer:
216
  self.model = "mixtral-8x7b"
217
  self.model_fullname = MODEL_MAP[self.model]
218
  self.message_outputer = OpenaiStreamOutputer(model=self.model)
219
- # self.tokenizer = AutoTokenizer.from_pretrained(self.model_fullname)
220
-
221
- # def count_tokens(self, text):
222
- # tokens = self.tokenizer.encode(text)
223
- # token_count = len(tokens)
224
- # logger.note(f"Prompt Token Count: {token_count}")
225
- # return token_count
226
 
227
  def chat_response(self, messages: list[dict], verbose=False):
228
  requester = HuggingchatRequester(model=self.model)
@@ -238,10 +267,11 @@ class HuggingchatStreamer:
238
 
239
 
240
  if __name__ == "__main__":
241
- # model = "llama3-70b"
242
- model = "command-r-plus"
243
- streamer = HuggingchatStreamer(model=model)
244
 
 
245
  messages = [
246
  {
247
  "role": "system",
 
24
  from messagers.message_composer import MessageComposer
25
 
26
 
27
+ class TokenChecker:
28
+ def __init__(self, input_str: str, model: str):
29
+ self.input_str = input_str
30
+
31
+ if model in MODEL_MAP.keys():
32
+ self.model = model
33
+ else:
34
+ self.model = "mixtral-8x7b"
35
+
36
+ self.model_fullname = MODEL_MAP[self.model]
37
+
38
+ if self.model == "llama3-70b":
39
+ # As original llama3 repo is gated and requires auth,
40
+ # I use NousResearch's version as a workaround
41
+ self.tokenizer = AutoTokenizer.from_pretrained(
42
+ "NousResearch/Meta-Llama-3-70B"
43
+ )
44
+ else:
45
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_fullname)
46
+
47
+ def count_tokens(self):
48
+ token_count = len(self.tokenizer.encode(self.input_str))
49
+ logger.note(f"Prompt Token Count: {token_count}")
50
+ return token_count
51
+
52
+ def check_token_limit(self):
53
+ token_limit = TOKEN_LIMIT_MAP[self.model]
54
+ token_redundancy = int(token_limit - TOKEN_RESERVED - self.count_tokens())
55
+ if token_redundancy <= 0:
56
+ raise ValueError(f"Prompt exceeded token limit: {token_limit}")
57
+ return True
58
+
59
+
60
  class HuggingchatRequester:
61
  def __init__(self, model: str):
62
  if model in MODEL_MAP.keys():
 
208
  messages
209
  )
210
 
211
+ checker = TokenChecker(input_str=system_prompt + input_prompt, model=self.model)
212
+ checker.check_token_limit()
213
+
214
  self.get_hf_chat_id()
215
  self.get_conversation_id(system_prompt=system_prompt)
216
  message_id = self.get_last_message_id()
 
252
  self.model = "mixtral-8x7b"
253
  self.model_fullname = MODEL_MAP[self.model]
254
  self.message_outputer = OpenaiStreamOutputer(model=self.model)
 
 
 
 
 
 
 
255
 
256
  def chat_response(self, messages: list[dict], verbose=False):
257
  requester = HuggingchatRequester(model=self.model)
 
267
 
268
 
269
  if __name__ == "__main__":
270
+ # model = "command-r-plus"
271
+ model = "llama3-70b"
272
+ # model = "zephyr-141b"
273
 
274
+ streamer = HuggingchatStreamer(model=model)
275
  messages = [
276
  {
277
  "role": "system",