mahynski committed
Commit 760198b · verified · 1 Parent(s): 4cb9000

add gemini token counter

Files changed (1): app.py (+19 -2)
app.py CHANGED
@@ -14,6 +14,24 @@ from streamlit_pdf_viewer import pdf_viewer
 
 MAX_OUTPUT_TOKENS = 2048
 
+
+class CountGeminiTokens:
+    """
+    Count tokens in Gemini models.
+
+    See: https://medium.com/google-cloud/counting-gemini-text-tokens-locally-with-the-vertex-ai-sdk-78979fea6244
+    """
+    def __init__(self, llm_name):
+        from vertexai.preview import tokenization
+        self.tokenizer = tokenization.get_tokenizer_for_model(llm_name)
+
+    def __call__(self, input):
+        """This returns all the tokens in a list since LlamaIndex seems to count by calling `len()` on the tokenizer function."""
+        tokens = []
+        for list in self.tokenizer.compute_tokens(input).token_info_list:
+            tokens += list.tokens
+        return tokens
+
 def main():
     with st.sidebar:
         st.title('Document Summarization and QA System')
@@ -84,7 +102,6 @@ def main():
             # raise NotImplementedError(f"{provider} is not supported yet")
             from llama_index.llms.gemini import Gemini
             from llama_index.embeddings.gemini import GeminiEmbedding
-            from vertexai.preview import tokenization
 
             os.environ['GOOGLE_API_KEY'] = str(llm_key)
             Settings.llm = Gemini(
@@ -93,7 +110,7 @@ def main():
                 temperature=temperature,
                 max_tokens=MAX_OUTPUT_TOKENS
             )
-            Settings.tokenizer = tokenization.get_tokenizer_for_model(llm_name).compute_tokens
+            Settings.tokenizer = CountGeminiTokens(llm_name) #tokenization.get_tokenizer_for_model(llm_name).compute_tokens
             Settings.num_output = MAX_OUTPUT_TOKENS
             Settings.embed_model = GeminiEmbedding(
                 model_name="models/text-embedding-004", api_key=os.environ.get("GOOGLE_API_KEY") #, title="this is a document"
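
For reference, not part of this commit: a minimal sketch of how the new CountGeminiTokens callable could be checked on its own and, optionally, routed through LlamaIndex's token accounting. The model name, the import from app.py, and the TokenCountingHandler wiring are illustrative assumptions, not something this diff adds; the commit itself only assigns the callable to Settings.tokenizer.

# Sketch only (not part of this commit). Assumes the local tokenizer extra is
# installed, e.g. pip install "google-cloud-aiplatform[tokenization]".
from llama_index.core import Settings
from llama_index.core.callbacks import CallbackManager, TokenCountingHandler

from app import CountGeminiTokens  # the class added in this commit (app.py)

counter = CountGeminiTokens("gemini-1.5-flash-001")  # placeholder; the app passes its own llm_name

# LlamaIndex counts tokens by calling len() on whatever Settings.tokenizer returns,
# so the flattened token list from __call__ is exactly the shape it expects.
tokens = counter("How many tokens does this sentence use?")
print(len(tokens))

# Optional: reuse the same callable in a TokenCountingHandler to accumulate
# LLM/embedding token usage across queries.
token_counter = TokenCountingHandler(tokenizer=counter)
Settings.callback_manager = CallbackManager([token_counter])
# ...after running queries, inspect token_counter.total_llm_token_count

Returning the flattened token list, rather than the raw result object from compute_tokens, is what makes the len()-based count come out right.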