bstraehle commited on
Commit
d07fa33
·
1 Parent(s): cd4364b

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +5 -6
rag.py CHANGED
@@ -90,13 +90,12 @@ def retrieve_mongodb():
90
  OpenAIEmbeddings(disallowed_special = ()),
91
  index_name = MONGODB_INDEX_NAME)
92
 
93
- def get_llm(config, openai_api_key):
94
  return ChatOpenAI(model_name = config["model_name"],
95
- openai_api_key = openai_api_key,
96
  temperature = config["temperature"])
97
 
98
- def llm_chain(config, openai_api_key, prompt):
99
- llm_chain = LLMChain(llm = get_llm(config, openai_api_key),
100
  prompt = LLM_CHAIN_PROMPT)
101
 
102
  with get_openai_callback() as cb:
@@ -104,8 +103,8 @@ def llm_chain(config, openai_api_key, prompt):
104
 
105
  return completion, llm_chain, cb
106
 
107
- def rag_chain(config, openai_api_key, rag_option, prompt):
108
- llm = get_llm(config, openai_api_key)
109
 
110
  if (rag_option == RAG_CHROMA):
111
  db = retrieve_chroma()
 
90
  OpenAIEmbeddings(disallowed_special = ()),
91
  index_name = MONGODB_INDEX_NAME)
92
 
93
+ def get_llm(config):
94
  return ChatOpenAI(model_name = config["model_name"],
 
95
  temperature = config["temperature"])
96
 
97
+ def llm_chain(config, prompt):
98
+ llm_chain = LLMChain(llm = get_llm(config),
99
  prompt = LLM_CHAIN_PROMPT)
100
 
101
  with get_openai_callback() as cb:
 
103
 
104
  return completion, llm_chain, cb
105
 
106
+ def rag_chain(config, rag_option, prompt):
107
+ llm = get_llm(config)
108
 
109
  if (rag_option == RAG_CHROMA):
110
  db = retrieve_chroma()