bstraehle commited on
Commit
a6c63bf
·
1 Parent(s): 7eac7c9

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +5 -5
rag.py CHANGED
@@ -89,20 +89,20 @@ def document_retrieval_mongodb(llm, prompt):
89
  OpenAIEmbeddings(disallowed_special = ()),
90
  index_name = MONGODB_INDEX_NAME)
91
 
92
- def get_llm():
93
  return ChatOpenAI(model_name = config["model_name"],
94
  openai_api_key = openai_api_key,
95
  temperature = config["temperature"])
96
 
97
- def llm_chain(prompt):
98
- llm_chain = LLMChain(llm = get_llm(),
99
  prompt = LLM_CHAIN_PROMPT,
100
  verbose = False)
101
  completion = llm_chain.generate([{"question": prompt}])
102
  return completion, llm_chain
103
 
104
- def rag_chain(prompt, db):
105
- rag_chain = RetrievalQA.from_chain_type(get_llm(),
106
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
107
  retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
108
  return_source_documents = True,
 
89
  OpenAIEmbeddings(disallowed_special = ()),
90
  index_name = MONGODB_INDEX_NAME)
91
 
92
+ def get_llm(openai_api_key):
93
  return ChatOpenAI(model_name = config["model_name"],
94
  openai_api_key = openai_api_key,
95
  temperature = config["temperature"])
96
 
97
+ def llm_chain(openai_api_key, prompt):
98
+ llm_chain = LLMChain(llm = get_llm(openai_api_key),
99
  prompt = LLM_CHAIN_PROMPT,
100
  verbose = False)
101
  completion = llm_chain.generate([{"question": prompt}])
102
  return completion, llm_chain
103
 
104
+ def rag_chain(openai_api_key, prompt, db):
105
+ rag_chain = RetrievalQA.from_chain_type(get_llm(openai_api_key),
106
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
107
  retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
108
  return_source_documents = True,