Update rag.py
Browse files
rag.py
CHANGED
def llm_chain(openai_api_key, prompt):
    """Answer *prompt* with a plain (non-retrieval) LLM chain.

    Wraps the model returned by ``get_llm(openai_api_key)`` in an
    ``LLMChain`` driven by the module-level ``LLM_CHAIN_PROMPT`` template,
    generates one completion for the question, and returns it together
    with the chain object so the caller can inspect or reuse the chain.
    """
    # Local is named `chain` (not `llm_chain`) to avoid shadowing this function.
    chain = LLMChain(llm=get_llm(openai_api_key),
                     prompt=LLM_CHAIN_PROMPT,
                     verbose=False)
    completion = chain.generate([{"question": prompt}])
    return completion, chain
-
def rag_chain(openai_api_key, prompt
|
105 |
-
|
|
|
|
|
|
|
|
|
106 |
chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
|
107 |
retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
|
108 |
return_source_documents = True,
|
109 |
verbose = False)
|
|
|
110 |
completion = rag_chain({"query": prompt})
|
|
|
111 |
return completion, rag_chain
|
|
|
98 |
def llm_chain(openai_api_key, prompt):
    """Run the bare LLM chain (no document retrieval) on one question.

    Returns a ``(completion, chain)`` pair: the generation result for
    *prompt* and the ``LLMChain`` instance that produced it.
    """
    model = get_llm(openai_api_key)
    chain = LLMChain(llm=model, prompt=LLM_CHAIN_PROMPT, verbose=False)
    # generate() takes a list of input dicts; we submit a single question.
    completion = chain.generate([{"question": prompt}])
    return completion, chain
105 |
|
106 |
+
def rag_chain(openai_api_key, prompt):
    """Answer *prompt* via retrieval-augmented generation (RAG).

    Builds the LLM, obtains a vector store through
    ``document_retrieval_chroma``, wires both into a ``RetrievalQA`` chain
    that also returns its source documents, runs the query, and returns
    the completion together with the chain object.
    """
    model = get_llm(openai_api_key)
    # NOTE(review): presumably returns a Chroma vector store; unclear from
    # here why it takes the prompt — confirm against its definition.
    store = document_retrieval_chroma(model, prompt)
    # Local is named `chain` (not `rag_chain`) to avoid shadowing this function.
    chain = RetrievalQA.from_chain_type(
        model,
        chain_type_kwargs={"prompt": RAG_CHAIN_PROMPT},
        retriever=store.as_retriever(search_kwargs={"k": config["k"]}),
        return_source_documents=True,
        verbose=False,
    )
    completion = chain({"query": prompt})
    return completion, chain
|