Update rag_llamaindex.py
rag_llamaindex.py  +20 -2
CHANGED
@@ -2,6 +2,7 @@ import os, requests
 
 from llama_hub.youtube_transcript import YoutubeTranscriptReader
 from llama_index import download_loader, PromptTemplate, ServiceContext
+from llama_index.callbacks import CallbackManager, TokenCountingHandler
 from llama_index.embeddings import OpenAIEmbedding
 from llama_index.indices.vector_store.base import VectorStoreIndex
 from llama_index.llms import OpenAI
@@ -51,6 +52,15 @@ class LlamaIndexRAG(BaseRAG):
 
         return docs
 
+    def get_callback_manager(self, config):
+        token_counter = TokenCountingHandler(
+            tokenizer = tiktoken.encoding_for_model(config["model_name"]).encode
+        )
+
+        token_counter.reset_counts()
+
+        return CallbackManager([token_counter])
+
     def get_llm(self, config):
         return OpenAI(
             model = config["model_name"],
@@ -67,6 +77,7 @@ class LlamaIndexRAG(BaseRAG):
 
     def get_service_context(self, config):
         return ServiceContext.from_defaults(
+            callback_manager = self.get_callback_manager(config),
            chunk_overlap = config["chunk_overlap"],
            chunk_size = config["chunk_size"],
            embed_model = OpenAIEmbedding(), # embed
@@ -99,10 +110,17 @@ class LlamaIndexRAG(BaseRAG):
            vector_store = self.get_vector_store()
        )
 
+        service_context = self.get_service_context(config)
+
        query_engine = index.as_query_engine(
            text_qa_template = PromptTemplate(os.environ["LLAMAINDEX_TEMPLATE"]),
-           service_context =
+            service_context = service_context,
            similarity_top_k = config["k"]
        )
 
-
+        completion = query_engine.query(prompt)
+
+        print("111 " + str(service_context.callback_manager.token_counter))
+        print("222 " + str(service_context.callback_manager.token_counter.total_embedding_token_count))
+
+        return completion
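
For context, a minimal sketch of how the counts collected by the TokenCountingHandler wired in above can be read back after a query. This is an illustration, not part of the commit: it assumes tiktoken is imported in the module (the added get_callback_manager relies on it), a llama_index version in which llama_index.callbacks and ServiceContext exist, and "gpt-3.5-turbo" is only a placeholder model name. The handler object itself is what exposes the count attributes, so it is usually kept in a variable rather than retrieved back out of the callback manager.

import tiktoken
from llama_index.callbacks import CallbackManager, TokenCountingHandler

# Same tokenizer choice as get_callback_manager(), with a placeholder model name
token_counter = TokenCountingHandler(
    tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode
)
callback_manager = CallbackManager([token_counter])

# ... pass callback_manager into ServiceContext.from_defaults(...) as in the diff,
# run query_engine.query(prompt), then inspect the accumulated counts:
print("embedding tokens :", token_counter.total_embedding_token_count)
print("prompt tokens    :", token_counter.prompt_llm_token_count)
print("completion tokens:", token_counter.completion_llm_token_count)
print("total LLM tokens :", token_counter.total_llm_token_count)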