bstraehle commited on
Commit
33f1a4f
1 Parent(s): 615c397

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -11
app.py CHANGED
@@ -50,6 +50,16 @@ def document_retrieval_chroma(llm, prompt):
50
  result = rag_chain({"query": prompt})
51
  return result["result"]
52
 
 
 
 
 
 
 
 
 
 
 
53
  def invoke(openai_api_key, rag, prompt):
54
  if (openai_api_key == ""):
55
  raise gr.Error("OpenAI API Key is required.")
@@ -57,11 +67,13 @@ def invoke(openai_api_key, rag, prompt):
57
  raise gr.Error("Retrieval Augmented Generation is required.")
58
  if (prompt == ""):
59
  raise gr.Error("Prompt is required.")
 
60
  try:
61
  llm = ChatOpenAI(model_name = MODEL_NAME,
62
  openai_api_key = openai_api_key,
63
  temperature = 0)
64
- if (rag != "None"):
 
65
  # Document loading
66
  #docs = []
67
  # Load PDF
@@ -85,15 +97,10 @@ def invoke(openai_api_key, rag, prompt):
85
  # embedding = OpenAIEmbeddings(disallowed_special = ()),
86
  # persist_directory = CHROMA_DIR)
87
  # Document retrieval
88
- ##vector_db = Chroma(embedding_function = OpenAIEmbeddings(),
89
- ## persist_directory = CHROMA_DIR)
90
- ##rag_chain = RetrievalQA.from_chain_type(llm,
91
- ## chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
92
- ## retriever = vector_db.as_retriever(search_kwargs = {"k": 3}),
93
- ## return_source_documents = True)
94
- ##result = rag_chain({"query": prompt})
95
- ##result = result["result"]
96
  result = document_retrieval_chroma(llm, prompt)
 
 
 
97
  else:
98
  chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT)
99
  result = chain.run({"question": prompt})
@@ -107,8 +114,8 @@ description = """<strong>Overview:</strong> Reasoning application that demonstra
107
  <a href='""" + YOUTUBE_URL_1 + """'>YouTube</a>, <a href='""" + PDF_URL + """'>PDF</a>, and <a href='""" + WEB_URL + """'>Web</a>
108
  <strong>data on GPT-4</strong> (published after LLM knowledge cutoff).
109
  <ul style="list-style-type:square;">
110
- <li>Set "Retrieval Augmented Generation" to "<strong>False</strong>" and submit prompt "What is GPT-4?" The LLM <strong>without</strong> RAG does not know the answer.</li>
111
- <li>Set "Retrieval Augmented Generation" to "<strong>True</strong>" and submit prompt "What is GPT-4?" The LLM <strong>with</strong> RAG knows the answer.</li>
112
  <li>Experiment with prompts, e.g. "What are GPT-4's media capabilities in 3 emojis and 1 sentence?", "List GPT-4's exam scores and benchmark results.", or "Compare GPT-4 to GPT-3.5 in markdown table format."</li>
113
  <li>Experiment some more, for example "What is the GPT-4 API's cost and rate limit? Answer in English, Arabic, Chinese, Hindi, and Russian in JSON format." or "Write a Python program that calls the GPT-4 API."</li>
114
  </ul>\n\n
 
50
  result = rag_chain({"query": prompt})
51
  return result["result"]
52
 
53
+ def document_retrieval_mongodb(llm, prompt):
54
+ vector_db = Chroma(embedding_function = OpenAIEmbeddings(),
55
+ persist_directory = CHROMA_DIR)
56
+ rag_chain = RetrievalQA.from_chain_type(llm,
57
+ chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
58
+ retriever = vector_db.as_retriever(search_kwargs = {"k": 3}),
59
+ return_source_documents = True)
60
+ result = rag_chain({"query": prompt})
61
+ return result["result"]
62
+
63
  def invoke(openai_api_key, rag, prompt):
64
  if (openai_api_key == ""):
65
  raise gr.Error("OpenAI API Key is required.")
 
67
  raise gr.Error("Retrieval Augmented Generation is required.")
68
  if (prompt == ""):
69
  raise gr.Error("Prompt is required.")
70
+
71
  try:
72
  llm = ChatOpenAI(model_name = MODEL_NAME,
73
  openai_api_key = openai_api_key,
74
  temperature = 0)
75
+
76
+ if (rag == "Chroma"):
77
  # Document loading
78
  #docs = []
79
  # Load PDF
 
97
  # embedding = OpenAIEmbeddings(disallowed_special = ()),
98
  # persist_directory = CHROMA_DIR)
99
  # Document retrieval
 
 
 
 
 
 
 
 
100
  result = document_retrieval_chroma(llm, prompt)
101
+ else if (rag == "MongoDB"):
102
+ # Document retrieval
103
+ result = document_retrieval_mongodb(llm, prompt)
104
  else:
105
  chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT)
106
  result = chain.run({"question": prompt})
 
114
  <a href='""" + YOUTUBE_URL_1 + """'>YouTube</a>, <a href='""" + PDF_URL + """'>PDF</a>, and <a href='""" + WEB_URL + """'>Web</a>
115
  <strong>data on GPT-4</strong> (published after LLM knowledge cutoff).
116
  <ul style="list-style-type:square;">
117
+ <li>Set "Retrieval Augmented Generation" to "<strong>None</strong>" and submit prompt "What is GPT-4?" The LLM <strong>without</strong> RAG does not know the answer.</li>
118
+ <li>Set "Retrieval Augmented Generation" to "<strong>Chroma</strong>" or "<strong>MongoDB</strong>" and submit prompt "What is GPT-4?" The LLM <strong>with</strong> RAG knows the answer.</li>
119
  <li>Experiment with prompts, e.g. "What are GPT-4's media capabilities in 3 emojis and 1 sentence?", "List GPT-4's exam scores and benchmark results.", or "Compare GPT-4 to GPT-3.5 in markdown table format."</li>
120
  <li>Experiment some more, for example "What is the GPT-4 API's cost and rate limit? Answer in English, Arabic, Chinese, Hindi, and Russian in JSON format." or "Write a Python program that calls the GPT-4 API."</li>
121
  </ul>\n\n