Atreyu4EVR committed on
Commit
50636d8
·
verified ·
1 Parent(s): 96d0c1a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -16
app.py CHANGED
@@ -77,24 +77,57 @@ def setup_advanced_rag_pipeline(model_name):
77
  # Set up language model
78
  llm = HuggingFaceHub(repo_id=model_links[model_name], model_kwargs={"temperature": 0.5, "max_length": 4000})
79
 
80
- # Set up HyDE
81
- hyde_prompt = PromptTemplate(
82
- input_variables=["question"],
83
- template="Please write a passage to answer the question\nQuestion: {question}\nPassage:"
84
- )
85
- hyde_chain = LLMChain(llm=llm, prompt=hyde_prompt)
86
 
87
- def hyde_retriever(query):
88
- hypothetical_doc = hyde_chain.run(query)
89
- hyde_embedding = embeddings.embed_query(hypothetical_doc)
90
- return vectorstore.similarity_search_by_vector(hyde_embedding, k=3)
91
 
92
- # Set up ContextualCompressionRetriever
93
- compressor = LLMChainExtractor.from_llm(llm)
94
- compression_retriever = ContextualCompressionRetriever(
95
- base_compressor=compressor,
96
- base_retriever=hyde_retriever
97
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  # Create RetrievalQA chain
100
  qa_chain = RetrievalQA.from_chain_type(
 
77
  # Set up language model
78
  llm = HuggingFaceHub(repo_id=model_links[model_name], model_kwargs={"temperature": 0.5, "max_length": 4000})
79
 
80
+ def load_and_process_json(file_path):
81
+ with open(file_path, 'r') as file:
82
+ data = json.load(file)
 
 
 
83
 
84
+ documents = data.get("documents", [])
 
 
 
85
 
86
+ if not documents:
87
+ raise ValueError("No valid documents found in JSON file.")
88
+
89
+ # Create Document objects
90
+ doc_objects = [
91
+ Document(
92
+ page_content=doc["content"],
93
+ metadata={"title": doc["title"], "id": doc["id"]}
94
+ ) for doc in documents
95
+ ]
96
+
97
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
98
+ splits = text_splitter.split_documents(doc_objects)
99
+
100
+ return splits
101
+
102
+ def get_vectorstore(file_path):
103
+ # Check if vectorstore already exists
104
+ if os.path.exists(VECTORSTORE_PATH):
105
+ print("Loading existing vectorstore...")
106
+ return Chroma(persist_directory=VECTORSTORE_PATH, embedding_function=embeddings)
107
+
108
+ print("Creating new vectorstore...")
109
+ splits = load_and_process_json(file_path)
110
+
111
+ # Process in batches
112
+ vectorstore = None
113
+ for i in tqdm(range(0, len(splits), BATCH_SIZE), desc="Processing batches"):
114
+ batch = splits[i:i+BATCH_SIZE]
115
+ if vectorstore is None:
116
+ vectorstore = Chroma.from_documents(documents=batch, embedding=embeddings, persist_directory=VECTORSTORE_PATH)
117
+ else:
118
+ vectorstore.add_documents(documents=batch)
119
+
120
+ vectorstore.persist()
121
+ return vectorstore
122
+
123
+ def setup_rag_pipeline(file_path):
124
+ vectorstore = get_vectorstore(file_path)
125
+ return RetrievalQA.from_chain_type(
126
+ llm=llm,
127
+ chain_type="stuff",
128
+ retriever=vectorstore.as_retriever(search_kwargs={"k": RETRIEVER_K}),
129
+ return_source_documents=True
130
+ )
131
 
132
  # Create RetrievalQA chain
133
  qa_chain = RetrievalQA.from_chain_type(