bstraehle committed · Commit 3707a95 · 1 Parent(s): 6ac712e

Update rag_llamaindex.py
Files changed (1): rag_llamaindex.py (+20 -2)
rag_llamaindex.py CHANGED

@@ -2,6 +2,7 @@ import os, requests
 
 from llama_hub.youtube_transcript import YoutubeTranscriptReader
 from llama_index import download_loader, PromptTemplate, ServiceContext
+from llama_index.callbacks import CallbackManager, TokenCountingHandler
 from llama_index.embeddings import OpenAIEmbedding
 from llama_index.indices.vector_store.base import VectorStoreIndex
 from llama_index.llms import OpenAI
@@ -51,6 +52,15 @@ class LlamaIndexRAG(BaseRAG):
 
         return docs
 
+    def get_callback_manager(self, config):
+        token_counter = TokenCountingHandler(
+            tokenizer = tiktoken.encoding_for_model(config["model_name"]).encode
+        )
+
+        token_counter.reset_counts()
+
+        return CallbackManager([token_counter])
+
     def get_llm(self, config):
         return OpenAI(
             model = config["model_name"],
@@ -67,6 +77,7 @@ class LlamaIndexRAG(BaseRAG):
 
     def get_service_context(self, config):
         return ServiceContext.from_defaults(
+            callback_manager = self.get_callback_manager(config),
             chunk_overlap = config["chunk_overlap"],
             chunk_size = config["chunk_size"],
             embed_model = OpenAIEmbedding(), # embed
@@ -99,10 +110,17 @@ class LlamaIndexRAG(BaseRAG):
             vector_store = self.get_vector_store()
         )
 
+        service_context = self.get_service_context(config)
+
         query_engine = index.as_query_engine(
             text_qa_template = PromptTemplate(os.environ["LLAMAINDEX_TEMPLATE"]),
-            service_context = self.get_service_context(config),
+            service_context = service_context,
             similarity_top_k = config["k"]
         )
 
-        return query_engine.query(prompt)
+        completion = query_engine.query(prompt)
+
+        print("111 " + str(service_context.callback_manager.token_counter))
+        print("222 " + str(service_context.callback_manager.token_counter.total_embedding_token_count))
+
+        return completion
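
For context, this change wires LlamaIndex's TokenCountingHandler into the ServiceContext so embedding and LLM token usage can be read back after a query. Below is a minimal, self-contained sketch of that pattern outside this repo's BaseRAG/config plumbing; the model name, the "docs/" directory, and the sample question are illustrative assumptions, not values from the commit.

import tiktoken

from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex
from llama_index.callbacks import CallbackManager, TokenCountingHandler

# Count tokens with the encoding that matches the target model (model name assumed).
token_counter = TokenCountingHandler(
    tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo").encode
)

service_context = ServiceContext.from_defaults(
    callback_manager = CallbackManager([token_counter])
)

# Build a small index; "docs/" is a placeholder for any local document folder.
documents = SimpleDirectoryReader("docs/").load_data()
index = VectorStoreIndex.from_documents(documents, service_context = service_context)

response = index.as_query_engine(service_context = service_context).query(
    "What are these documents about?"
)

# Counters accumulate across embedding and LLM calls until reset_counts() is called.
print("embedding tokens :", token_counter.total_embedding_token_count)
print("prompt tokens    :", token_counter.prompt_llm_token_count)
print("completion tokens:", token_counter.completion_llm_token_count)
print("total LLM tokens :", token_counter.total_llm_token_count)

token_counter.reset_counts()

Two caveats on the hunks above: get_callback_manager calls tiktoken, so the file needs a module-level import tiktoken that is not visible in these hunks, and the debug prints read the counter via service_context.callback_manager.token_counter, an attribute CallbackManager may not expose depending on the llama_index version; keeping a direct reference to the handler, as in the sketch, avoids that dependency.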