teddyllm commited on
Commit
0e0cf5a
1 Parent(s): ea10cb7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -6,19 +6,24 @@ from langchain_openai import OpenAIEmbeddings
6
  from langchain_core.output_parsers import StrOutputParser
7
  from langchain_core.runnables import RunnablePassthrough
8
  from langchain_core.prompts import ChatPromptTemplate
 
 
 
 
9
  import gradio as gr
10
 
11
 
12
  def format_docs(docs):
13
- print(docs)
14
  return "\n\n".join(doc.page_content for doc in docs)
15
 
16
 
17
-
18
  embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
19
  db = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)
20
- retriever = db.as_retriever(k=5)
21
-
 
 
 
22
 
23
  llm = ChatOpenAI(model="gpt-4o")
24
 
@@ -33,7 +38,7 @@ prompt = ChatPromptTemplate.from_messages([
33
  ])
34
 
35
  rag_chain = (
36
- {"context": retriever | format_docs, "question": RunnablePassthrough()}
37
  | prompt
38
  | llm
39
  | StrOutputParser()
@@ -48,7 +53,6 @@ def chat_gen(message, history):
48
 
49
  partial_message=""
50
  for chunk in rag_chain.stream(message):
51
- # if chunk.choices[0].delta.content is not None:
52
  partial_message = partial_message + chunk
53
  yield partial_message
54
 
@@ -59,7 +63,7 @@ demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()
59
 
60
 
61
  try:
62
- demo.launch(debug=True, share=True, show_api=False)
63
  demo.close()
64
  except Exception as e:
65
  demo.close()
 
6
  from langchain_core.output_parsers import StrOutputParser
7
  from langchain_core.runnables import RunnablePassthrough
8
  from langchain_core.prompts import ChatPromptTemplate
9
+ from langchain.retrievers import ContextualCompressionRetriever
10
+ from langchain.retrievers.document_compressors import FlashrankRerank
11
+ from langchain_openai import ChatOpenAI
12
+
13
  import gradio as gr
14
 
15
 
16
  def format_docs(docs):
 
17
  return "\n\n".join(doc.page_content for doc in docs)
18
 
19
 
 
20
  embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
21
  db = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)
22
+ retriever = db.as_retriever(k=10)
23
+ compressor = FlashrankRerank()
24
+ compression_retriever = ContextualCompressionRetriever(
25
+ base_compressor=compressor, base_retriever=retriever
26
+ )
27
 
28
  llm = ChatOpenAI(model="gpt-4o")
29
 
 
38
  ])
39
 
40
  rag_chain = (
41
+ {"context": compression_retriever | format_docs, "question": RunnablePassthrough()}
42
  | prompt
43
  | llm
44
  | StrOutputParser()
 
53
 
54
  partial_message=""
55
  for chunk in rag_chain.stream(message):
 
56
  partial_message = partial_message + chunk
57
  yield partial_message
58
 
 
63
 
64
 
65
  try:
66
+ demo.launch(debug=True, share=False, show_api=False)
67
  demo.close()
68
  except Exception as e:
69
  demo.close()