bstraehle commited on
Commit
bf1b617
1 Parent(s): 3d17aed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -26
app.py CHANGED
@@ -13,6 +13,8 @@ from langchain.prompts import PromptTemplate
13
  from langchain.text_splitter import RecursiveCharacterTextSplitter
14
  from langchain.vectorstores import Chroma
15
 
 
 
16
  from dotenv import load_dotenv, find_dotenv
17
  _ = load_dotenv(find_dotenv())
18
 
@@ -40,7 +42,7 @@ YOUTUBE_URL_3 = "https://www.youtube.com/watch?v=vw-KWfKwvTQ"
40
 
41
  MODEL_NAME = "gpt-4"
42
 
43
- def document_storage_chroma():
44
  # Document loading
45
  docs = []
46
  # Load PDF
@@ -59,31 +61,14 @@ def document_storage_chroma():
59
  text_splitter = RecursiveCharacterTextSplitter(chunk_overlap = 150,
60
  chunk_size = 1500)
61
  splits = text_splitter.split_documents(docs)
62
- # Document storage
 
 
63
  vector_db = Chroma.from_documents(documents = splits,
64
  embedding = OpenAIEmbeddings(disallowed_special = ()),
65
  persist_directory = CHROMA_DIR)
66
 
67
- def document_storage_mongodb():
68
- # Document loading
69
- docs = []
70
- # Load PDF
71
- loader = PyPDFLoader(PDF_URL)
72
- docs.extend(loader.load())
73
- # Load Web
74
- loader = WebBaseLoader(WEB_URL_1)
75
- docs.extend(loader.load())
76
- # Load YouTube
77
- loader = GenericLoader(YoutubeAudioLoader([YOUTUBE_URL_1,
78
- YOUTUBE_URL_2,
79
- YOUTUBE_URL_3], YOUTUBE_DIR),
80
- OpenAIWhisperParser())
81
- docs.extend(loader.load())
82
- # Document splitting
83
- text_splitter = RecursiveCharacterTextSplitter(chunk_overlap = 150,
84
- chunk_size = 1500)
85
- splits = text_splitter.split_documents(docs)
86
- # Document storage
87
  vector_db = Chroma.from_documents(documents = splits,
88
  embedding = OpenAIEmbeddings(disallowed_special = ()),
89
  persist_directory = CHROMA_DIR)
@@ -115,17 +100,16 @@ def invoke(openai_api_key, rag_option, prompt):
115
  raise gr.Error("Retrieval Augmented Generation is required.")
116
  if (prompt == ""):
117
  raise gr.Error("Prompt is required.")
118
-
119
  try:
120
  llm = ChatOpenAI(model_name = MODEL_NAME,
121
  openai_api_key = openai_api_key,
122
  temperature = 0)
123
-
124
  if (rag_option == "Chroma"):
125
- #document_storage_chroma()
126
  result = document_retrieval_chroma(llm, prompt)
127
  elif (rag_option == "MongoDB"):
128
- #document_storage_mongodb()
129
  result = document_retrieval_mongodb(llm, prompt)
130
  else:
131
  chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT)
 
13
  from langchain.text_splitter import RecursiveCharacterTextSplitter
14
  from langchain.vectorstores import Chroma
15
 
16
+ from pymongo import MongoClient
17
+
18
  from dotenv import load_dotenv, find_dotenv
19
  _ = load_dotenv(find_dotenv())
20
 
 
42
 
43
  MODEL_NAME = "gpt-4"
44
 
45
+ def document_loading_splitting():
46
  # Document loading
47
  docs = []
48
  # Load PDF
 
61
  text_splitter = RecursiveCharacterTextSplitter(chunk_overlap = 150,
62
  chunk_size = 1500)
63
  splits = text_splitter.split_documents(docs)
64
+ return splits
65
+
66
+ def document_storage_chroma(splits):
67
  vector_db = Chroma.from_documents(documents = splits,
68
  embedding = OpenAIEmbeddings(disallowed_special = ()),
69
  persist_directory = CHROMA_DIR)
70
 
71
+ def document_storage_mongodb(splits):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  vector_db = Chroma.from_documents(documents = splits,
73
  embedding = OpenAIEmbeddings(disallowed_special = ()),
74
  persist_directory = CHROMA_DIR)
 
100
  raise gr.Error("Retrieval Augmented Generation is required.")
101
  if (prompt == ""):
102
  raise gr.Error("Prompt is required.")
 
103
  try:
104
  llm = ChatOpenAI(model_name = MODEL_NAME,
105
  openai_api_key = openai_api_key,
106
  temperature = 0)
107
+ #splits = document_loading_splitting()
108
  if (rag_option == "Chroma"):
109
+ #document_storage_chroma(splits)
110
  result = document_retrieval_chroma(llm, prompt)
111
  elif (rag_option == "MongoDB"):
112
+ #document_storage_mongodb(splits)
113
  result = document_retrieval_mongodb(llm, prompt)
114
  else:
115
  chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT)