karubiniumu commited on
Commit
57957da
1 Parent(s): b335705
Files changed (1) hide show
  1. app.py +74 -38
app.py CHANGED
@@ -9,42 +9,63 @@ from haystack.components.embedders import SentenceTransformersTextEmbedder
9
  from haystack.components.joiners import DocumentJoiner
10
  from haystack.components.rankers import TransformersSimilarityRanker
11
  from haystack import Pipeline,component,Document
12
- from haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator
13
- from haystack.components.builders import PromptBuilder
14
-
15
- @component
16
- class AnswerBuilder :
17
- @component.output_types(reply=str,documents=list[Document])
18
- def run(self,replies:list[str],documents:list[Document]):
19
- reply = replies[0]
20
- return {"reply":reply,'documents':documents}
21
 
22
  document_store = InMemoryDocumentStore.load_from_disk(path='./document_store.json')
23
  print('document_store loaded' ,document_store.count_documents())
24
 
 
 
 
 
 
 
 
 
25
  class Niwa_rag :
26
  def __init__(self):
27
  self.createPipe()
28
  def createPipe(self):
29
- template = """
30
- 以下の情報に基づいて質問に答えて下さい。
 
 
 
 
 
31
 
32
- Context:
33
  {% for document in documents %}
34
  {{ document.content }}
35
  {% endfor %}
36
 
37
- 質問: {{question}}
38
  回答:
39
  """
 
 
 
 
40
  text_embedder = SentenceTransformersTextEmbedder(model=sentense_transformers_model)
41
  embedding_retriever = InMemoryEmbeddingRetriever(document_store)
42
  bm25_retriever = InMemoryBM25Retriever(document_store)
43
  document_joiner = DocumentJoiner()
44
- ranker = TransformersSimilarityRanker(model=ranker_model,top_k=6)
45
- prompt_builder = PromptBuilder(template=template)
46
- gemini = GoogleAIGeminiGenerator(model="models/gemini-1.0-pro")
47
- answer_builder = AnswerBuilder()
 
 
 
 
48
 
49
  pipe = Pipeline()
50
  pipe.add_component("text_embedder", text_embedder)
@@ -54,43 +75,58 @@ class Niwa_rag :
54
  pipe.add_component("ranker", ranker)
55
  pipe.add_component("prompt_builder", prompt_builder)
56
  pipe.add_component("llm", gemini)
57
- pipe.add_component("answer_builder", answer_builder)
 
 
58
 
59
  pipe.connect("text_embedder", "embedding_retriever")
60
  pipe.connect("bm25_retriever", "document_joiner")
61
  pipe.connect("embedding_retriever", "document_joiner")
62
  pipe.connect("document_joiner", "ranker")
63
  pipe.connect("ranker.documents", "prompt_builder.documents")
64
- pipe.connect("prompt_builder.prompt", "llm")
65
- pipe.connect("llm.replies", "answer_builder.replies")
66
- pipe.connect("ranker.documents", "answer_builder.documents")
 
 
67
  self.pipe = pipe
68
  def run(self,q):
 
69
  if not q :
70
- return ['','']
71
- result = self.pipe.run(
72
- {"text_embedder": {"text": q}, "bm25_retriever": {"query": q}, "ranker": {"query": q},'prompt_builder':{'question':q}}
73
- )
74
- reply = result['answer_builder']['reply']
75
- html = '<div>参考記事</div><div style="margin-left:30px">'
76
- for doc in result['answer_builder']['documents'] :
 
 
 
 
 
 
77
  title = doc.meta['title']
78
  link = doc.meta['link']
79
- row = f'<div><a href="{link}" target="_blank">{title}</a></div>'
80
  html += row
 
81
  html += '</div>'
82
- return [reply,html]
83
 
84
  rag = Niwa_rag()
85
 
86
- with gr.Blocks() as app:
87
- inputs=gr.Textbox(label='質問')
88
- submit = gr.Button("送信",variant="primary")
89
- reply =gr.Textbox(label='回答')
90
- sources =gr.HTML(label='参考記事')
91
- submit.click(lambda: gr.update(interactive=False),inputs=None, outputs=submit) \
92
- .then(fn=rag.run, inputs=inputs, outputs=[reply,sources] ) \
93
- .then(fn=lambda: gr.update(interactive=True),inputs=None, outputs=submit)
 
 
 
94
 
95
  if __name__ == "__main__":
96
  app.launch(share=True)
 
9
  from haystack.components.joiners import DocumentJoiner
10
  from haystack.components.rankers import TransformersSimilarityRanker
11
  from haystack import Pipeline,component,Document
12
+ from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator
13
+ from haystack.components.builders import ChatPromptBuilder
14
+ from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore
15
+ from haystack_experimental.components.retrievers import ChatMessageRetriever
16
+ from haystack_experimental.components.writers import ChatMessageWriter
17
+ from haystack.dataclasses import ChatMessage
18
+ from itertools import chain
19
+ from typing import Any,List
20
+ from haystack.core.component.types import Variadic
21
 
22
  document_store = InMemoryDocumentStore.load_from_disk(path='./document_store.json')
23
  print('document_store loaded' ,document_store.count_documents())
24
 
25
+ @component
26
+ class ListJoiner:
27
+ def __init__(self, _type: Any):
28
+ component.set_output_types(self, values=_type)
29
+ def run(self, values: Variadic[Any]):
30
+ result = list(chain(*values))
31
+ return {"values": result}
32
+
33
  class Niwa_rag :
34
  def __init__(self):
35
  self.createPipe()
36
  def createPipe(self):
37
+ user_message_template = """
38
+ 会話の履歴と提供された資料に基づいて、質問に答えてください。
39
+
40
+ 会話の履歴:
41
+ {% for memory in memories %}
42
+ {{ memory.content }}
43
+ {% endfor %}
44
 
45
+ 資料:
46
  {% for document in documents %}
47
  {{ document.content }}
48
  {% endfor %}
49
 
50
+ 質問: {{query}}
51
  回答:
52
  """
53
+ system_message = ChatMessage.from_system("あなたは、提供された資料と会話履歴を使用して人間を支援するAIアシスタントです")
54
+ user_message = ChatMessage.from_user(user_message_template)
55
+ messages = [system_message, user_message]
56
+
57
  text_embedder = SentenceTransformersTextEmbedder(model=sentense_transformers_model)
58
  embedding_retriever = InMemoryEmbeddingRetriever(document_store)
59
  bm25_retriever = InMemoryBM25Retriever(document_store)
60
  document_joiner = DocumentJoiner()
61
+ ranker = TransformersSimilarityRanker(model=ranker_model,top_k=8)
62
+ prompt_builder = ChatPromptBuilder(template=messages,variables=["query", "documents", "memories"], required_variables=["query", "documents", "memories"])
63
+ gemini = GoogleAIGeminiChatGenerator(model="models/gemini-1.0-pro")
64
+
65
+ memory_store = InMemoryChatMessageStore()
66
+ memory_joiner = ListJoiner(List[ChatMessage])
67
+ memory_retriever = ChatMessageRetriever(memory_store)
68
+ memory_writer = ChatMessageWriter(memory_store)
69
 
70
  pipe = Pipeline()
71
  pipe.add_component("text_embedder", text_embedder)
 
75
  pipe.add_component("ranker", ranker)
76
  pipe.add_component("prompt_builder", prompt_builder)
77
  pipe.add_component("llm", gemini)
78
+ pipe.add_component("memory_retriever", memory_retriever)
79
+ pipe.add_component("memory_writer", memory_writer)
80
+ pipe.add_component("memory_joiner", memory_joiner)
81
 
82
  pipe.connect("text_embedder", "embedding_retriever")
83
  pipe.connect("bm25_retriever", "document_joiner")
84
  pipe.connect("embedding_retriever", "document_joiner")
85
  pipe.connect("document_joiner", "ranker")
86
  pipe.connect("ranker.documents", "prompt_builder.documents")
87
+ pipe.connect("prompt_builder.prompt", "llm.messages")
88
+ pipe.connect("llm.replies", "memory_joiner")
89
+ pipe.connect("memory_joiner", "memory_writer")
90
+ pipe.connect("memory_retriever", "prompt_builder.memories")
91
+
92
  self.pipe = pipe
93
  def run(self,q):
94
+ print('q:',q)
95
  if not q :
96
+ return {'reply':'','sources':''}
97
+ result = self.pipe.run({
98
+ "text_embedder": {"text": q},
99
+ "bm25_retriever": {"query": q},
100
+ "ranker": {"query": q},
101
+ "prompt_builder": { "query": q},
102
+ "memory_joiner": {"values": [ChatMessage.from_user(q)]},
103
+ },include_outputs_from=["llm",'ranker'])
104
+ reply = result['llm']['replies'][0]
105
+ docs = result['ranker']['documents']
106
+ print('reply:',reply)
107
+ html = '<div class="ref-title">参考記事</div><div class="ref">'
108
+ for doc in docs :
109
  title = doc.meta['title']
110
  link = doc.meta['link']
111
+ row = f'<div><a class="link" href="{link}" target="_blank">{title}</a></div>'
112
  html += row
113
+ print('',title,link,doc.score)
114
  html += '</div>'
115
+ return {'reply':reply.content,'sources':html}
116
 
117
  rag = Niwa_rag()
118
 
119
+ def fn(q,history):
120
+ result = rag.run(q)
121
+ return result['reply'] + result['sources']
122
+
123
+ app = gr.ChatInterface(
124
+ fn,
125
+ type="messages",
126
+ textbox=gr.Textbox(placeholder='質問を記入して下さい'),
127
+ css='.ref-title{margin-top:10px} .ref .link{margin-left:20px;font-size:90%;color:-webkit-link; !important}'
128
+ )
129
+
130
 
131
  if __name__ == "__main__":
132
  app.launch(share=True)