mahynski commited on
Commit
0444d24
1 Parent(s): f5ca364

updated token counts and context windows explicitly for each model

Browse files
Files changed (1) hide show
  1. app.py +40 -17
app.py CHANGED
@@ -11,8 +11,6 @@ from llama_index.core import (
11
  from llama_parse import LlamaParse
12
  from streamlit_pdf_viewer import pdf_viewer
13
 
14
- MAX_OUTPUT_TOKENS = 2048
15
-
16
  class MistralTokens:
17
  """
18
  Returns tokens for MistralAI models.
@@ -71,11 +69,11 @@ def main():
71
 
72
  # Select LLM
73
  if provider == 'google':
74
- llm_list = ['gemini-1.0-pro', 'gemini-1.5-flash', 'gemini-1.5-pro', 'aqa']
75
  elif provider == 'huggingface':
76
  llm_list = []
77
  elif provider == 'mistralai':
78
- llm_list = ["mistral-small-latest", "mistral-large-latest", "open-mistral-nemo-latest"]
79
  elif provider == 'openai':
80
  llm_list = ['gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo', 'gpt-4o', 'gpt-4o-mini']
81
  else:
@@ -127,75 +125,100 @@ def main():
127
  if provider == 'google':
128
  from llama_index.llms.gemini import Gemini
129
  from llama_index.embeddings.gemini import GeminiEmbedding
130
-
 
131
  os.environ['GOOGLE_API_KEY'] = str(llm_key)
132
  Settings.llm = Gemini(
133
  model=f"models/{llm_name}",
134
  token=os.environ.get("GOOGLE_API_KEY"),
135
  temperature=temperature,
136
- max_tokens=MAX_OUTPUT_TOKENS
137
  )
138
  Settings.tokenizer = GeminiTokens(llm_name)
139
- Settings.num_output = MAX_OUTPUT_TOKENS
140
  Settings.embed_model = GeminiEmbedding(
141
  model_name="models/text-embedding-004", api_key=os.environ.get("GOOGLE_API_KEY") #, title="this is a document"
142
  )
143
- Settings.context_window = 4096
 
 
 
 
144
  elif provider == 'huggingface':
145
  if llm_name is not None and embed_name is not None:
146
  from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
147
  from llama_index.embeddings.huggingface import HuggingFaceInferenceAPIEmbedding
148
  from transformers import AutoTokenizer
149
 
 
 
150
  os.environ['HFTOKEN'] = str(llm_key)
151
  Settings.llm = HuggingFaceInferenceAPI(
152
  model_name=llm_name,
153
  token=os.environ.get("HFTOKEN"),
154
  temperature=temperature,
155
- max_tokens=MAX_OUTPUT_TOKENS
156
  )
157
  Settings.tokenizer = AutoTokenizer.from_pretrained(
158
  llm_name,
159
  token=os.environ.get("HFTOKEN"),
160
  )
161
- Settings.num_output = MAX_OUTPUT_TOKENS
162
  Settings.embed_model = HuggingFaceInferenceAPIEmbedding(
163
  model_name=embed_name
164
  )
165
- Settings.context_window = 4096
166
  elif provider == 'mistralai':
167
  from llama_index.llms.mistralai import MistralAI
168
  from llama_index.embeddings.mistralai import MistralAIEmbedding
 
169
 
170
  os.environ['MISTRAL_API_KEY'] = str(llm_key)
171
  Settings.llm = MistralAI(
172
  model=llm_name,
173
  temperature=temperature,
174
- max_tokens=MAX_OUTPUT_TOKENS,
175
  random_seed=42,
176
  safe_mode=True
177
  )
178
  Settings.tokenizer = MistralTokens(llm_name)
179
- Settings.num_output = MAX_OUTPUT_TOKENS
180
  Settings.embed_model = MistralAIEmbedding(
181
  model_name="mistral-embed",
182
  api_key=os.environ.get("MISTRAL_API_KEY")
183
  )
184
- Settings.context_window = 4096 # max possible
185
  elif provider == 'openai':
186
  from llama_index.llms.openai import OpenAI
187
  from llama_index.embeddings.openai import OpenAIEmbedding
188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  os.environ["OPENAI_API_KEY"] = str(llm_key)
190
  Settings.llm = OpenAI(
191
  model=llm_name,
192
  temperature=temperature,
193
- max_tokens=MAX_OUTPUT_TOKENS
194
  )
195
  Settings.tokenizer = tiktoken.encoding_for_model(llm_name).encode
196
- Settings.num_output = MAX_OUTPUT_TOKENS
197
  Settings.embed_model = OpenAIEmbedding()
198
- Settings.context_window = 4096 # max possible
199
  else:
200
  raise NotImplementedError(f"{provider} is not supported yet")
201
 
 
11
  from llama_parse import LlamaParse
12
  from streamlit_pdf_viewer import pdf_viewer
13
 
 
 
14
  class MistralTokens:
15
  """
16
  Returns tokens for MistralAI models.
 
69
 
70
  # Select LLM
71
  if provider == 'google':
72
+ llm_list = ['gemini-1.0-pro', 'gemini-1.5-flash', 'gemini-1.5-pro']
73
  elif provider == 'huggingface':
74
  llm_list = []
75
  elif provider == 'mistralai':
76
+ llm_list = ["mistral-large-latest", "open-mistral-nemo-latest"]
77
  elif provider == 'openai':
78
  llm_list = ['gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo', 'gpt-4o', 'gpt-4o-mini']
79
  else:
 
125
  if provider == 'google':
126
  from llama_index.llms.gemini import Gemini
127
  from llama_index.embeddings.gemini import GeminiEmbedding
128
+ max_output_tokens = 8192 # https://firebase.google.com/docs/vertex-ai/gemini-models
129
+
130
  os.environ['GOOGLE_API_KEY'] = str(llm_key)
131
  Settings.llm = Gemini(
132
  model=f"models/{llm_name}",
133
  token=os.environ.get("GOOGLE_API_KEY"),
134
  temperature=temperature,
135
+ max_tokens=max_output_tokens
136
  )
137
  Settings.tokenizer = GeminiTokens(llm_name)
138
+ Settings.num_output = max_output_tokens
139
  Settings.embed_model = GeminiEmbedding(
140
  model_name="models/text-embedding-004", api_key=os.environ.get("GOOGLE_API_KEY") #, title="this is a document"
141
  )
142
+ if llm_name == 'gemini-1.0-pro':
143
+ total_token_limit = 32760
144
+ else:
145
+ total_token_limit = 1_000_000
146
+ Settings.context_window = total_token_limit - max_output_tokens # Gemini counts total tokens
147
  elif provider == 'huggingface':
148
  if llm_name is not None and embed_name is not None:
149
  from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
150
  from llama_index.embeddings.huggingface import HuggingFaceInferenceAPIEmbedding
151
  from transformers import AutoTokenizer
152
 
153
+ max_output_tokens = 2048 # Just a generic value
154
+
155
  os.environ['HFTOKEN'] = str(llm_key)
156
  Settings.llm = HuggingFaceInferenceAPI(
157
  model_name=llm_name,
158
  token=os.environ.get("HFTOKEN"),
159
  temperature=temperature,
160
+ max_tokens=max_output_tokens
161
  )
162
  Settings.tokenizer = AutoTokenizer.from_pretrained(
163
  llm_name,
164
  token=os.environ.get("HFTOKEN"),
165
  )
166
+ Settings.num_output = max_output_tokens
167
  Settings.embed_model = HuggingFaceInferenceAPIEmbedding(
168
  model_name=embed_name
169
  )
170
+ Settings.context_window = 4096 # Just a generic value
171
  elif provider == 'mistralai':
172
  from llama_index.llms.mistralai import MistralAI
173
  from llama_index.embeddings.mistralai import MistralAIEmbedding
174
+ max_output_tokens = 8192 # Based on internet consensus since this is not well documented
175
 
176
  os.environ['MISTRAL_API_KEY'] = str(llm_key)
177
  Settings.llm = MistralAI(
178
  model=llm_name,
179
  temperature=temperature,
180
+ max_tokens=max_output_tokens,
181
  random_seed=42,
182
  safe_mode=True
183
  )
184
  Settings.tokenizer = MistralTokens(llm_name)
185
+ Settings.num_output = max_output_tokens
186
  Settings.embed_model = MistralAIEmbedding(
187
  model_name="mistral-embed",
188
  api_key=os.environ.get("MISTRAL_API_KEY")
189
  )
190
+ Settings.context_window = 128000 # 128k for flagship models - doesn't seem to count input tokens
191
  elif provider == 'openai':
192
  from llama_index.llms.openai import OpenAI
193
  from llama_index.embeddings.openai import OpenAIEmbedding
194
 
195
+ # https://platform.openai.com/docs/models/gpt-4-turbo-and-gpt-4
196
+ if llm_name == 'gpt-3.5-turbo':
197
+ max_output_tokens = 4096
198
+ context_window = 16385
199
+ elif llm_name == 'gpt-4':
200
+ max_output_tokens = 8192
201
+ context_window = 8192
202
+ elif llm_name == 'gpt-4-turbo':
203
+ max_output_tokens = 4096
204
+ context_window = 128000
205
+ elif llm_name == 'gpt-4o':
206
+ max_output_tokens = 4096
207
+ context_window = 128000
208
+ elif llm_name == 'gpt-4o-mini':
209
+ max_output_tokens = 16384
210
+ context_window = 128000
211
+
212
  os.environ["OPENAI_API_KEY"] = str(llm_key)
213
  Settings.llm = OpenAI(
214
  model=llm_name,
215
  temperature=temperature,
216
+ max_tokens=max_output_tokens
217
  )
218
  Settings.tokenizer = tiktoken.encoding_for_model(llm_name).encode
219
+ Settings.num_output = max_output_tokens
220
  Settings.embed_model = OpenAIEmbedding()
221
+ Settings.context_window = context_window
222
  else:
223
  raise NotImplementedError(f"{provider} is not supported yet")
224