bstraehle commited on
Commit
3ede494
1 Parent(s): 516ec1c

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +80 -0
rag.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  PDF_URL = "https://arxiv.org/pdf/2303.08774.pdf"
2
  WEB_URL = "https://openai.com/research/gpt-4"
3
  YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
@@ -14,3 +16,81 @@ MONGODB_INDEX_NAME = "default"
14
 
15
  LLM_CHAIN_PROMPT = PromptTemplate(input_variables = ["question"], template = os.environ["LLM_TEMPLATE"])
16
  RAG_CHAIN_PROMPT = PromptTemplate(input_variables = ["context", "question"], template = os.environ["RAG_TEMPLATE"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
  PDF_URL = "https://arxiv.org/pdf/2303.08774.pdf"
4
  WEB_URL = "https://openai.com/research/gpt-4"
5
  YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
 
16
 
17
  LLM_CHAIN_PROMPT = PromptTemplate(input_variables = ["question"], template = os.environ["LLM_TEMPLATE"])
18
  RAG_CHAIN_PROMPT = PromptTemplate(input_variables = ["context", "question"], template = os.environ["RAG_TEMPLATE"])
19
+
20
+ RAG_OFF = "Off"
21
+ RAG_CHROMA = "Chroma"
22
+ RAG_MONGODB = "MongoDB"
23
+
24
+ client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
25
+ collection = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]
26
+
27
+ config = {
28
+ "chunk_overlap": 150,
29
+ "chunk_size": 1500,
30
+ "k": 3,
31
+ "model_name": "gpt-4-0613",
32
+ "temperature": 0,
33
+ }
34
+
35
+ def document_loading_splitting():
36
+ # Document loading
37
+ docs = []
38
+
39
+ # Load PDF
40
+ loader = PyPDFLoader(PDF_URL)
41
+ docs.extend(loader.load())
42
+
43
+ # Load Web
44
+ loader = WebBaseLoader(WEB_URL)
45
+ docs.extend(loader.load())
46
+
47
+ # Load YouTube
48
+ loader = GenericLoader(YoutubeAudioLoader([YOUTUBE_URL_1,
49
+ YOUTUBE_URL_2,
50
+ YOUTUBE_URL_3], YOUTUBE_DIR),
51
+ OpenAIWhisperParser())
52
+ docs.extend(loader.load())
53
+
54
+ # Document splitting
55
+ text_splitter = RecursiveCharacterTextSplitter(chunk_overlap = config["chunk_overlap"],
56
+ chunk_size = config["chunk_size"])
57
+ split_documents = text_splitter.split_documents(docs)
58
+
59
+ return split_documents
60
+
61
+ def document_storage_chroma(documents):
62
+ Chroma.from_documents(documents = documents,
63
+ embedding = OpenAIEmbeddings(disallowed_special = ()),
64
+ persist_directory = CHROMA_DIR)
65
+
66
+ def document_storage_mongodb(documents):
67
+ MongoDBAtlasVectorSearch.from_documents(documents = documents,
68
+ embedding = OpenAIEmbeddings(disallowed_special = ()),
69
+ collection = collection,
70
+ index_name = MONGODB_INDEX_NAME)
71
+
72
+ def document_retrieval_chroma(llm, prompt):
73
+ return Chroma(embedding_function = OpenAIEmbeddings(),
74
+ persist_directory = CHROMA_DIR)
75
+
76
+ def document_retrieval_mongodb(llm, prompt):
77
+ return MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI,
78
+ MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
79
+ OpenAIEmbeddings(disallowed_special = ()),
80
+ index_name = MONGODB_INDEX_NAME)
81
+
82
+ def llm_chain(llm, prompt):
83
+ llm_chain = LLMChain(llm = llm,
84
+ prompt = LLM_CHAIN_PROMPT,
85
+ verbose = False)
86
+ completion = llm_chain.generate([{"question": prompt}])
87
+ return completion, llm_chain
88
+
89
+ def rag_chain(llm, prompt, db):
90
+ rag_chain = RetrievalQA.from_chain_type(llm,
91
+ chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
92
+ retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
93
+ return_source_documents = True,
94
+ verbose = False)
95
+ completion = rag_chain({"query": prompt})
96
+ return completion, rag_chain