Hansimov commited on
Commit
3f608c6
·
1 Parent(s): 9f5d69c

:recycle: [Refactor] Move MODELS_MAP to constants

Browse files
constants/__init__.py ADDED
File without changes
constants/models.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ MODEL_MAP = {
2
+ "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1", # [Recommended]
3
+ "nous-mixtral-8x7b": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
4
+ "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.2",
5
+ "openchat-3.5": "openchat/openchat-3.5-0106",
6
+ "gemma-7b": "google/gemma-7b-it",
7
+ "default": "mistralai/Mixtral-8x7B-Instruct-v0.1",
8
+ }
networks/message_streamer.py CHANGED
@@ -1,26 +1,17 @@
1
  import json
2
  import re
3
  import requests
 
4
  from tiktoken import get_encoding as tiktoken_get_encoding
 
 
5
  from messagers.message_outputer import OpenaiStreamOutputer
 
6
  from utils.logger import logger
7
  from utils.enver import enver
8
- from transformers import AutoTokenizer
9
 
10
 
11
  class MessageStreamer:
12
- MODEL_MAP = {
13
- "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1", # 72.62, fast [Recommended]
14
- "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.2", # 65.71, fast
15
- "nous-mixtral-8x7b": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
16
- "openchat-3.5": "openchat/openchat-3.5-0106",
17
- "gemma-7b": "google/gemma-7b-it",
18
- # "zephyr-7b-beta": "HuggingFaceH4/zephyr-7b-beta", # ❌ Too Slow
19
- # "llama-70b": "meta-llama/Llama-2-70b-chat-hf", # ❌ Require Pro User
20
- # "codellama-34b": "codellama/CodeLlama-34b-Instruct-hf", # ❌ Low Score
21
- # "falcon-180b": "tiiuae/falcon-180B-chat", # ❌ Require Pro User
22
- "default": "mistralai/Mixtral-8x7B-Instruct-v0.1",
23
- }
24
  STOP_SEQUENCES_MAP = {
25
  "mixtral-8x7b": "</s>",
26
  "mistral-7b": "</s>",
@@ -35,17 +26,22 @@ class MessageStreamer:
35
  "openchat-3.5": 8192,
36
  "gemma-7b": 8192,
37
  }
38
- TOKEN_RESERVED = 100
39
 
40
  def __init__(self, model: str):
41
- if model in self.MODEL_MAP.keys():
42
  self.model = model
43
  else:
44
  self.model = "default"
45
- self.model_fullname = self.MODEL_MAP[self.model]
46
  self.message_outputer = OpenaiStreamOutputer()
47
- # self.tokenizer = tiktoken_get_encoding("cl100k_base")
48
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_fullname)
 
 
 
 
 
49
 
50
  def parse_line(self, line):
51
  line = line.decode("utf-8")
@@ -98,7 +94,7 @@ class MessageStreamer:
98
  token_limit = int(
99
  self.TOKEN_LIMIT_MAP[self.model]
100
  - self.TOKEN_RESERVED
101
- - self.count_tokens(prompt) * 1.35
102
  )
103
  if token_limit <= 0:
104
  raise ValueError("Prompt exceeded token limit!")
 
1
  import json
2
  import re
3
  import requests
4
+
5
  from tiktoken import get_encoding as tiktoken_get_encoding
6
+ from transformers import AutoTokenizer
7
+
8
  from messagers.message_outputer import OpenaiStreamOutputer
9
+ from constants.models import MODEL_MAP
10
  from utils.logger import logger
11
  from utils.enver import enver
 
12
 
13
 
14
  class MessageStreamer:
 
 
 
 
 
 
 
 
 
 
 
 
15
  STOP_SEQUENCES_MAP = {
16
  "mixtral-8x7b": "</s>",
17
  "mistral-7b": "</s>",
 
26
  "openchat-3.5": 8192,
27
  "gemma-7b": 8192,
28
  }
29
+ TOKEN_RESERVED = 20
30
 
31
  def __init__(self, model: str):
32
+ if model in MODEL_MAP.keys():
33
  self.model = model
34
  else:
35
  self.model = "default"
36
+ self.model_fullname = MODEL_MAP[self.model]
37
  self.message_outputer = OpenaiStreamOutputer()
38
+
39
+ if self.model == "gemma-7b":
40
+ # this is not wrong, as repo `google/gemma-7b-it` is gated and must authenticate to access it
41
+ # so I use mistral-7b as a fallback
42
+ self.tokenizer = AutoTokenizer.from_pretrained(MODEL_MAP["mistral-7b"])
43
+ else:
44
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_fullname)
45
 
46
  def parse_line(self, line):
47
  line = line.decode("utf-8")
 
94
  token_limit = int(
95
  self.TOKEN_LIMIT_MAP[self.model]
96
  - self.TOKEN_RESERVED
97
+ - self.count_tokens(prompt)
98
  )
99
  if token_limit <= 0:
100
  raise ValueError("Prompt exceeded token limit!")