bstraehle commited on
Commit
f43960a
·
1 Parent(s): 7d2deb5

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +10 -2
rag.py CHANGED
@@ -98,14 +98,22 @@ def llm_chain(openai_api_key, prompt):
98
  llm_chain = LLMChain(llm = get_llm(openai_api_key),
99
  prompt = LLM_CHAIN_PROMPT,
100
  verbose = False)
 
101
  completion = llm_chain.generate([{"question": prompt}])
 
102
  return completion, llm_chain
103
 
104
- def rag_chain(openai_api_key, prompt, db):
105
- rag_chain = RetrievalQA.from_chain_type(get_llm(openai_api_key),
 
 
 
 
106
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
107
  retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
108
  return_source_documents = True,
109
  verbose = False)
 
110
  completion = rag_chain({"query": prompt})
 
111
  return completion, rag_chain
 
98
  llm_chain = LLMChain(llm = get_llm(openai_api_key),
99
  prompt = LLM_CHAIN_PROMPT,
100
  verbose = False)
101
+
102
  completion = llm_chain.generate([{"question": prompt}])
103
+
104
  return completion, llm_chain
105
 
106
+ def rag_chain(openai_api_key, prompt):
107
+ llm = get_llm(openai_api_key)
108
+
109
+ db = document_retrieval_chroma(llm, prompt)
110
+
111
+ rag_chain = RetrievalQA.from_chain_type(llm,
112
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
113
  retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
114
  return_source_documents = True,
115
  verbose = False)
116
+
117
  completion = rag_chain({"query": prompt})
118
+
119
  return completion, rag_chain