mahynski committed
Commit 4ebc26e · verified · 1 Parent(s): 760198b

updated mistral token counter

Files changed (1): app.py (+28, -7)
app.py CHANGED
@@ -3,7 +3,6 @@ import os
 import tiktoken
 import streamlit as st
 
-# from llama_index.llms.gemini import Gemini
 from llama_index.core import (
     VectorStoreIndex,
     Settings,
@@ -14,10 +13,33 @@ from streamlit_pdf_viewer import pdf_viewer
 
 MAX_OUTPUT_TOKENS = 2048
 
+class MistralTokens:
+    """
+    Returns tokens for MistralAI models.
+
+    See: https://docs.mistral.ai/guides/tokenization/
+    """
+    def __init__(self, llm_name):
+        from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+        self.tokenizer = MistralTokenizer.from_model(llm_name)
 
+    def __call__(self, input):
+        """This returns all the tokens indices in a list since LlamaIndex seems to count by calling `len()` on the tokenizer function."""
+        from mistral_common.protocol.instruct.messages import UserMessage
+        from mistral_common.protocol.instruct.request import ChatCompletionRequest
+
+        return self.tokenizer.encode_chat_completion(
+            ChatCompletionRequest(
+                tools=[],
+                messages=[
+                    UserMessage(content=input)
+                ]
+            )
+        ).tokens
+
-class CountGeminiTokens:
+class GeminiTokens:
     """
-    Count tokens in Gemini models.
+    Returns tokens for Gemini models.
 
     See: https://medium.com/google-cloud/counting-gemini-text-tokens-locally-with-the-vertex-ai-sdk-78979fea6244
     """
@@ -99,7 +121,6 @@ def main():
     # https://docs.llamaindex.ai/en/stable/module_guides/models/llms/
     if llm_key is not None:
         if provider == 'google':
-            # raise NotImplementedError(f"{provider} is not supported yet")
             from llama_index.llms.gemini import Gemini
             from llama_index.embeddings.gemini import GeminiEmbedding
 
@@ -110,7 +131,7 @@ def main():
                 temperature=temperature,
                 max_tokens=MAX_OUTPUT_TOKENS
             )
-            Settings.tokenizer = CountGeminiTokens(llm_name) #tokenization.get_tokenizer_for_model(llm_name).compute_tokens
+            Settings.tokenizer = GeminiTokens(llm_name)
             Settings.num_output = MAX_OUTPUT_TOKENS
             Settings.embed_model = GeminiEmbedding(
                 model_name="models/text-embedding-004", api_key=os.environ.get("GOOGLE_API_KEY") #, title="this is a document"
@@ -141,7 +162,7 @@ def main():
         elif provider == 'mistralai':
             from llama_index.llms.mistralai import MistralAI
             from llama_index.embeddings.mistralai import MistralAIEmbedding
-
+
             os.environ['MISTRAL_API_KEY'] = str(llm_key)
             Settings.llm = MistralAI(
                 model=llm_name,
@@ -150,7 +171,7 @@ def main():
                 random_seed=42,
                 safe_mode=True
            )
-            # Settings.tokenizer = tiktoken.encoding_for_model(llm_name).encode
+            Settings.tokenizer = MistralTokens(llm_name)
             Settings.num_output = MAX_OUTPUT_TOKENS
             Settings.embed_model = MistralAIEmbedding(
                 model_name="mistral-embed",