bstraehle commited on
Commit
9549818
1 Parent(s): 89e77a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -8
app.py CHANGED
@@ -40,6 +40,16 @@ YOUTUBE_URL_3 = "https://www.youtube.com/watch?v=vw-KWfKwvTQ"
40
 
41
  MODEL_NAME = "gpt-4"
42
 
 
 
 
 
 
 
 
 
 
 
43
  def invoke(openai_api_key, use_rag, rag_db, prompt):
44
  if (openai_api_key == ""):
45
  raise gr.Error("OpenAI API Key is required.")
@@ -75,14 +85,15 @@ def invoke(openai_api_key, use_rag, rag_db, prompt):
75
  # embedding = OpenAIEmbeddings(disallowed_special = ()),
76
  # persist_directory = CHROMA_DIR)
77
  # Document retrieval
78
- vector_db = Chroma(embedding_function = OpenAIEmbeddings(),
79
- persist_directory = CHROMA_DIR)
80
- rag_chain = RetrievalQA.from_chain_type(llm,
81
- chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
82
- retriever = vector_db.as_retriever(search_kwargs = {"k": 3}),
83
- return_source_documents = True)
84
- result = rag_chain({"query": prompt})
85
- result = result["result"]
 
86
  else:
87
  chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT)
88
  result = chain.run({"question": prompt})
 
40
 
41
  MODEL_NAME = "gpt-4"
42
 
43
+ def document_retrieval_chroma():
44
+ vector_db = Chroma(embedding_function = OpenAIEmbeddings(),
45
+ persist_directory = CHROMA_DIR)
46
+ rag_chain = RetrievalQA.from_chain_type(llm,
47
+ chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
48
+ retriever = vector_db.as_retriever(search_kwargs = {"k": 3}),
49
+ return_source_documents = True)
50
+ result = rag_chain({"query": prompt})
51
+ return result["result"]
52
+
53
  def invoke(openai_api_key, use_rag, rag_db, prompt):
54
  if (openai_api_key == ""):
55
  raise gr.Error("OpenAI API Key is required.")
 
85
  # embedding = OpenAIEmbeddings(disallowed_special = ()),
86
  # persist_directory = CHROMA_DIR)
87
  # Document retrieval
88
+ ##vector_db = Chroma(embedding_function = OpenAIEmbeddings(),
89
+ ## persist_directory = CHROMA_DIR)
90
+ ##rag_chain = RetrievalQA.from_chain_type(llm,
91
+ ## chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
92
+ ## retriever = vector_db.as_retriever(search_kwargs = {"k": 3}),
93
+ ## return_source_documents = True)
94
+ ##result = rag_chain({"query": prompt})
95
+ ##result = result["result"]
96
+ result = document_retrieval_chroma()
97
  else:
98
  chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT)
99
  result = chain.run({"question": prompt})