hbui commited on
Commit
1bc345a
·
verified ·
1 Parent(s): f55600a

Update models/langOpen.py

Browse files
Files changed (1) hide show
  1. models/langOpen.py +11 -1
models/langOpen.py CHANGED
@@ -9,6 +9,10 @@ from langchain.embeddings import OpenAIEmbeddings
9
  from langchain.prompts import PromptTemplate
10
  from langchain_pinecone import PineconeVectorStore
11
 
 
 
 
 
12
  prompt_template = """Answer the question using the given context to the best of your ability.
13
  If you don't know, answer I don't know.
14
  Context: {context}
@@ -33,7 +37,13 @@ class LangOpen:
33
  def get_response(self, query_str):
34
  print("query_str: ", query_str)
35
  print("model_name: ", self.llm.model_name)
36
- docs = self.index.similarity_search(query_str, k=4)
 
 
 
 
 
 
37
  inputs = [{"context": doc.page_content, "topic": query_str} for doc in docs]
38
  result = self.chain.apply(inputs)[0]["text"]
39
  return result
 
9
  from langchain.prompts import PromptTemplate
10
  from langchain_pinecone import PineconeVectorStore
11
 
12
+ from langchain.retrievers import ContextualCompressionRetriever
13
+ from langchain.retrievers.document_compressors import CohereRerank
14
+ from langchain_community.llms import Cohere
15
+
16
  prompt_template = """Answer the question using the given context to the best of your ability.
17
  If you don't know, answer I don't know.
18
  Context: {context}
 
37
  def get_response(self, query_str):
38
  print("query_str: ", query_str)
39
  print("model_name: ", self.llm.model_name)
40
+ #docs = self.index.similarity_search(query_str, k=4)
41
+ vectorstore_retriever = self.index.as_retriever(search_type="similarity", search_kwargs={"k": 10})
42
+ compressor = CohereRerank()
43
+ compression_retriever = ContextualCompressionRetriever(
44
+ base_compressor=compressor, base_retriever=vectorstore_retriever
45
+ )
46
+ docs = compression_retriever.get_relevant_documents(query_str)
47
  inputs = [{"context": doc.page_content, "topic": query_str} for doc in docs]
48
  result = self.chain.apply(inputs)[0]["text"]
49
  return result