karubiniumu committed on
Commit
804908b
1 Parent(s): e2a8726

hybrid + emb_only

Browse files
Files changed (2) hide show
  1. app.py +23 -11
  2. pipe.py +8 -1
app.py CHANGED
@@ -23,8 +23,9 @@ def run(pipe,q):
23
  result = pipe.run({
24
  "query_rephrase_prompt_builder":{"query":q},
25
  "prompt_builder": { "query": q},
 
26
  "memory_joiner": {"values": [ChatMessage.from_user(q)]},
27
- },include_outputs_from=["query_rephrase_llm","llm",'ranker','bm25_retriever','embedding_retriever'])
28
 
29
  query_rephrase = result['query_rephrase_llm']['replies'][0]
30
  print('query_rephrase:',query_rephrase)
@@ -36,22 +37,33 @@ def run(pipe,q):
36
 
37
  reply = result['llm']['replies'][0]
38
  docs = result['ranker']['documents']
39
- print('reply:',reply.content)
40
  print('ranker')
41
  log_docs(docs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- html = '<div class="ref-title">参考記事</div><div class="ref">'
44
- for doc in get_unique_docs(docs) :
45
- title = doc.meta['title']
46
- link = doc.meta['link']
47
- row = f'<div><a class="link" href="{link}" target="_blank">{title}</a></div>'
48
- html += row
49
- html += '</div>'
50
- return {'reply':reply.content,'sources':html}
51
 
52
  def rag(q,history):
53
  result = run(pipe,q)
54
- return result['reply'] + result['sources']
 
55
 
56
  app = gr.ChatInterface(
57
  rag,
 
23
  result = pipe.run({
24
  "query_rephrase_prompt_builder":{"query":q},
25
  "prompt_builder": { "query": q},
26
+ "prompt_builder_emb_only": { "query": q},
27
  "memory_joiner": {"values": [ChatMessage.from_user(q)]},
28
+ },include_outputs_from=["query_rephrase_llm","llm",'ranker','bm25_retriever','embedding_retriever','llm_emb_only'])
29
 
30
  query_rephrase = result['query_rephrase_llm']['replies'][0]
31
  print('query_rephrase:',query_rephrase)
 
37
 
38
  reply = result['llm']['replies'][0]
39
  docs = result['ranker']['documents']
 
40
  print('ranker')
41
  log_docs(docs)
42
+ print('reply:',reply.content)
43
+
44
+ def create_response(reply,docs) :
45
+ html = '<div class="ref-title">参考記事</div><div class="ref">'
46
+ for doc in get_unique_docs(docs) :
47
+ title = doc.meta['title']
48
+ link = doc.meta['link']
49
+ row = f'<div><a class="link" href="{link}" target="_blank">{title}</a></div>'
50
+ html += row
51
+ html += '</div>'
52
+ return {'reply':reply.content,'sources':html}
53
+
54
+ hybrid_response = create_response(reply,docs)
55
+
56
+ reply_emb_only = result['llm_emb_only']['replies'][0]
57
+ print('reply_emb_only :',reply_emb_only.content)
58
 
59
+ emb_only_response = create_response(reply_emb_only,result['embedding_retriever']['documents'])
60
+
61
+ return {'hybrid':hybrid_response,'emb_only' :emb_only_response}
 
 
 
 
 
62
 
63
  def rag(q,history):
64
  result = run(pipe,q)
65
+ return '<h3>ハイブリッド検索</h3>' + result['hybrid']['reply'] + result['hybrid']['sources'] + \
66
+ '<h3>ベクトル検索のみ</h3>'+ result['emb_only']['reply'] + result['emb_only']['sources']
67
 
68
  app = gr.ChatInterface(
69
  rag,
pipe.py CHANGED
@@ -92,6 +92,8 @@ document_joiner = DocumentJoiner()
92
  ranker = TransformersSimilarityRanker(model=ranker_model,meta_fields_to_embed=['title'],model_kwargs={'trust_remote_code':True})
93
  prompt_builder = ChatPromptBuilder(template=messages,variables=["query", "documents", "memories"], required_variables=["query", "documents", "memories"])
94
  gemini = GoogleAIGeminiChatGenerator(model=gemini_model)
 
 
95
 
96
  memory_store = InMemoryChatMessageStore()
97
  memory_joiner = ListJoiner(List[ChatMessage])
@@ -113,6 +115,8 @@ pipe.add_component("llm", gemini)
113
  pipe.add_component("memory_retriever", memory_retriever)
114
  pipe.add_component("memory_joiner", memory_joiner)
115
  pipe.add_component("memory_writer", memory_writer)
 
 
116
 
117
  pipe.connect("memory_retriever", "query_rephrase_prompt_builder.memories")
118
  pipe.connect("query_rephrase_prompt_builder.prompt", "query_rephrase_llm")
@@ -129,4 +133,7 @@ pipe.connect("ranker.documents", "prompt_builder.documents")
129
  pipe.connect("memory_retriever", "prompt_builder.memories")
130
  pipe.connect("prompt_builder.prompt", "llm.messages")
131
  pipe.connect("llm.replies", "memory_joiner")
132
- pipe.connect("memory_joiner", "memory_writer")
 
 
 
 
92
  ranker = TransformersSimilarityRanker(model=ranker_model,meta_fields_to_embed=['title'],model_kwargs={'trust_remote_code':True})
93
  prompt_builder = ChatPromptBuilder(template=messages,variables=["query", "documents", "memories"], required_variables=["query", "documents", "memories"])
94
  gemini = GoogleAIGeminiChatGenerator(model=gemini_model)
95
+ prompt_builder_emb_only = ChatPromptBuilder(template=messages,variables=["query", "documents", "memories"], required_variables=["query", "documents", "memories"])
96
+ gemini_emb_only = GoogleAIGeminiChatGenerator(model=gemini_model)
97
 
98
  memory_store = InMemoryChatMessageStore()
99
  memory_joiner = ListJoiner(List[ChatMessage])
 
115
  pipe.add_component("memory_retriever", memory_retriever)
116
  pipe.add_component("memory_joiner", memory_joiner)
117
  pipe.add_component("memory_writer", memory_writer)
118
+ pipe.add_component("prompt_builder_emb_only", prompt_builder_emb_only)
119
+ pipe.add_component("llm_emb_only", gemini_emb_only)
120
 
121
  pipe.connect("memory_retriever", "query_rephrase_prompt_builder.memories")
122
  pipe.connect("query_rephrase_prompt_builder.prompt", "query_rephrase_llm")
 
133
  pipe.connect("memory_retriever", "prompt_builder.memories")
134
  pipe.connect("prompt_builder.prompt", "llm.messages")
135
  pipe.connect("llm.replies", "memory_joiner")
136
+ pipe.connect("memory_joiner", "memory_writer")
137
+ pipe.connect("embedding_retriever.documents", "prompt_builder_emb_only.documents")
138
+ pipe.connect("memory_retriever", "prompt_builder_emb_only.memories")
139
+ pipe.connect("prompt_builder_emb_only.prompt", "llm_emb_only.messages")