karubiniumu committed
Commit e2a8726
1 Parent(s): 71c7a2c

query_rephrase

Files changed (2):
  1. app.py +47 -143
  2. pipe.py +132 -0
app.py CHANGED
@@ -1,156 +1,60 @@
  import gradio as gr
- import datetime
  import pytz
-
- sentense_transformers_model = "Alibaba-NLP/gte-multilingual-base"
- ranker_model = 'Alibaba-NLP/gte-multilingual-reranker-base'
-
- from haystack.document_stores.in_memory import InMemoryDocumentStore
- from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
- from haystack.components.embedders import SentenceTransformersTextEmbedder
- from haystack.components.joiners import DocumentJoiner
- from haystack.components.rankers import TransformersSimilarityRanker
- from haystack import Pipeline,component,Document
- from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator
- from haystack.components.builders import ChatPromptBuilder
- from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore
- from haystack_experimental.components.retrievers import ChatMessageRetriever
- from haystack_experimental.components.writers import ChatMessageWriter
  from haystack.dataclasses import ChatMessage
- from itertools import chain
- from typing import Any,List
- from haystack.core.component.types import Variadic
- from fugashi import Tagger

- document_store = InMemoryDocumentStore.load_from_disk(path='./document_store.json')
- print('document_store loaded' ,document_store.count_documents())

- @component
- class ListJoiner:
-     def __init__(self, _type: Any):
-         component.set_output_types(self, values=_type)
-     def run(self, values: Variadic[Any]):
-         result = list(chain(*values))
-         return {"values": result}
-
- tagger = Tagger()
- def gen_wakachi(text):
-     words = tagger(text)
-     return ' '.join([word.surface for word in words])
-
- class Niwa_rag :
-     def __init__(self):
-         self.createPipe()
-     def createPipe(self):
-         user_message_template = """
- チャット履歴と提供された資料に基づいて、質問に答えてください。
- 資料はチャット履歴の一部ではないことに注意してください。
- 質問が資料から回答できない場合は、その旨を述べてください。
-
- チャット履歴:
- {% for memory in memories %}
- {{ memory.content }}
- {% endfor %}
-
- 資料:
- {% for document in documents %}
- {{ document.content }}
- {% endfor %}
-
- 質問: {{query}}
- 回答:
- """
-         system_message = ChatMessage.from_system("あなたは、提供された資料とチャット履歴を使用して人間を支援するAIアシスタントです")
-         user_message = ChatMessage.from_user(user_message_template)
-         messages = [system_message, user_message]

-         text_embedder = SentenceTransformersTextEmbedder(model=sentense_transformers_model,trust_remote_code=True)
-         embedding_retriever = InMemoryEmbeddingRetriever(document_store)
-         bm25_retriever = InMemoryBM25Retriever(document_store)
-         document_joiner = DocumentJoiner()
-         ranker = TransformersSimilarityRanker(model=ranker_model,meta_fields_to_embed=['title'],model_kwargs={'trust_remote_code':True})
-         prompt_builder = ChatPromptBuilder(template=messages,variables=["query", "documents", "memories"], required_variables=["query", "documents", "memories"])
-         gemini = GoogleAIGeminiChatGenerator(model="models/gemini-1.5-flash")
-
-         memory_store = InMemoryChatMessageStore()
-         memory_joiner = ListJoiner(List[ChatMessage])
-         memory_retriever = ChatMessageRetriever(memory_store)
-         memory_writer = ChatMessageWriter(memory_store)
-
-         pipe = Pipeline()
-         pipe.add_component("text_embedder", text_embedder)
-         pipe.add_component("embedding_retriever", embedding_retriever)
-         pipe.add_component("bm25_retriever", bm25_retriever)
-         pipe.add_component("document_joiner", document_joiner)
-         pipe.add_component("ranker", ranker)
-         pipe.add_component("prompt_builder", prompt_builder)
-         pipe.add_component("llm", gemini)
-         pipe.add_component("memory_retriever", memory_retriever)
-         pipe.add_component("memory_writer", memory_writer)
-         pipe.add_component("memory_joiner", memory_joiner)
-
-         pipe.connect("text_embedder", "embedding_retriever")
-         pipe.connect("bm25_retriever", "document_joiner")
-         pipe.connect("embedding_retriever", "document_joiner")
-         pipe.connect("document_joiner", "ranker")
-         pipe.connect("ranker.documents", "prompt_builder.documents")
-         pipe.connect("prompt_builder.prompt", "llm.messages")
-         pipe.connect("llm.replies", "memory_joiner")
-         pipe.connect("memory_joiner", "memory_writer")
-         pipe.connect("memory_retriever", "prompt_builder.memories")
-
-         self.pipe = pipe
-     def run(self,q):
-         now = datetime.datetime.now(pytz.timezone('Asia/Tokyo'))
-         print('\nq:',q,now)
-         if not q :
-             return {'reply':'','sources':''}
-
-         result = self.pipe.run({
-             "text_embedder": {"text": q},
-             "bm25_retriever": {"query": gen_wakachi(q)},
-             "ranker": {"query": q},
-             "prompt_builder": { "query": q},
-             "memory_joiner": {"values": [ChatMessage.from_user(q)]},
-         },include_outputs_from=["llm",'ranker','bm25_retriever','embedding_retriever'])
-
-         def log_document(doc):
-             title = doc.meta['title']
-             link = doc.meta['link']
-             print('',title,link,doc.meta['type'],doc.score)
-
-         for retriever in ['bm25_retriever','embedding_retriever'] :
-             print(retriever)
-             docs = result[retriever]['documents']
-             for doc in docs :
-                 log_document(doc)
-
-         reply = result['llm']['replies'][0]
-         docs = result['ranker']['documents']
-         print('reply:',reply)
-         print('ranker')
-         def get_unique_docs(docs):
-             source_ids = set([doc.meta['source_id'] for doc in docs])
-             _docs = sorted([[doc for doc in docs if doc.meta['source_id']==source_id][0] for source_id in source_ids],key=lambda x:x.score,reverse=True)
-             return _docs
-         html = '<div class="ref-title">参考記事</div><div class="ref">'
-         for doc in get_unique_docs(docs) :
-             log_document(doc)
-             title = doc.meta['title']
-             link = doc.meta['link']
-             row = f'<div><a class="link" href="{link}" target="_blank">{title}</a></div>'
-             html += row
-         html += '</div>'
-         return {'reply':reply.content,'sources':html}
-
- rag = Niwa_rag()

- def fn(q,history):
-     result = rag.run(q)
      return result['reply'] + result['sources']

  app = gr.ChatInterface(
-     fn,
      type="messages",
      title='庭ファン Chatbot',
      textbox=gr.Textbox(placeholder='質問を記入して下さい',submit_btn=True),
 
  import gradio as gr
+ from pipe import pipe
  import pytz
+ import datetime
  from haystack.dataclasses import ChatMessage

+ def log_docs(docs):
+     for doc in docs :
+         title = doc.meta['title']
+         link = doc.meta['link']
+         print('',title,link,doc.meta['type'],doc.score)
+ def get_unique_docs(docs):
+     source_ids = set([doc.meta['source_id'] for doc in docs])
+     _docs = sorted([[doc for doc in docs if doc.meta['source_id']==source_id][0] for source_id in source_ids],key=lambda x:x.score,reverse=True)
+     return _docs
+
+ def run(pipe,q):
+     now = datetime.datetime.now(pytz.timezone('Asia/Tokyo'))
+     print('\nq:',q,now)
+     if not q :
+         return {'reply':'','sources':''}
+
+     result = pipe.run({
+         "query_rephrase_prompt_builder":{"query":q},
+         "prompt_builder": { "query": q},
+         "memory_joiner": {"values": [ChatMessage.from_user(q)]},
+     },include_outputs_from=["query_rephrase_llm","llm",'ranker','bm25_retriever','embedding_retriever'])

+     query_rephrase = result['query_rephrase_llm']['replies'][0]
+     print('query_rephrase:',query_rephrase)
+
+     for retriever in ['bm25_retriever','embedding_retriever'] :
+         print(retriever)
+         docs = result[retriever]['documents']
+         log_docs(docs)

+     reply = result['llm']['replies'][0]
+     docs = result['ranker']['documents']
+     print('reply:',reply.content)
+     print('ranker')
+     log_docs(docs)
+
+     html = '<div class="ref-title">参考記事</div><div class="ref">'
+     for doc in get_unique_docs(docs) :
+         title = doc.meta['title']
+         link = doc.meta['link']
+         row = f'<div><a class="link" href="{link}" target="_blank">{title}</a></div>'
+         html += row
+     html += '</div>'
+     return {'reply':reply.content,'sources':html}

+ def rag(q,history):
+     result = run(pipe,q)
      return result['reply'] + result['sources']

  app = gr.ChatInterface(
+     rag,
      type="messages",
      title='庭ファン Chatbot',
      textbox=gr.Textbox(placeholder='質問を記入して下さい',submit_btn=True),
pipe.py ADDED
@@ -0,0 +1,132 @@
+ sentense_transformers_model = "Alibaba-NLP/gte-multilingual-base"
+ ranker_model = 'Alibaba-NLP/gte-multilingual-reranker-base'
+ gemini_model = 'models/gemini-1.5-flash'
+
+
+ from haystack.document_stores.in_memory import InMemoryDocumentStore
+ from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
+ from haystack.components.embedders import SentenceTransformersTextEmbedder
+ from haystack.components.joiners import DocumentJoiner
+ from haystack.components.rankers import TransformersSimilarityRanker
+ from haystack import Pipeline,component,Document
+ from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator,GoogleAIGeminiGenerator
+ from haystack.components.builders import ChatPromptBuilder,PromptBuilder
+ from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore
+ from haystack_experimental.components.retrievers import ChatMessageRetriever
+ from haystack_experimental.components.writers import ChatMessageWriter
+ from haystack.dataclasses import ChatMessage
+ from itertools import chain
+ from typing import Any,List
+ from haystack.core.component.types import Variadic
+ from haystack.components.converters import OutputAdapter
+ from fugashi import Tagger
+
+ tagger = Tagger()
+ def gen_wakachi(text):
+     words = tagger(text)
+     return ' '.join([word.surface for word in words])
+
+ @component
+ class ListJoiner:
+     def __init__(self, _type: Any):
+         component.set_output_types(self, values=_type)
+     def run(self, values: Variadic[Any]):
+         result = list(chain(*values))
+         return {"values": result}
+
+ @component
+ class WakachiAdapter:
+     @component.output_types(wakachi=str)
+     def run(self, query:str):
+         return {'wakachi':gen_wakachi(query)}
+
+ document_store = InMemoryDocumentStore.load_from_disk(path='./document_store.json')
+ print('document_store loaded' ,document_store.count_documents())
+
+ user_message_template = """
+ チャット履歴と提供された資料に基づいて、質問に答えてください。
+ 資料はチャット履歴の一部ではないことに注意してください。
+ 質問が資料から回答できない場合は、その旨を述べてください。
+
+ チャット履歴:
+ {% for memory in memories %}
+ {{ memory.content }}
+ {% endfor %}
+
+ 資料:
+ {% for document in documents %}
+ {{ document.content }}
+ {% endfor %}
+
+ 質問: {{query}}
+ 回答:
+ """
+ query_rephrase_template = """
+ 意味とキーワードをそのまま維持しながら、検索用の質問を書き直してください。
+ チャット履歴が空の場合は、クエリを変更しないでください。
+ チャット履歴は必要な場合にのみ使用し、独自の知識でクエリを拡張することは避けてください。
+ 変更の必要がない場合は、現在の質問をそのまま出力して下さい。
+
+ チャット履歴 :
+ {% for memory in memories %}
+ {{ memory.text }}
+ {% endfor %}
+
+ ユーザーの質問 : {{query}}
+ 書き換えられた質問 :
+ """
+
+ system_message = ChatMessage.from_system("あなたは、提供された資料とチャット履歴を使用して人間を支援するAIアシスタントです")
+ user_message = ChatMessage.from_user(user_message_template)
+ messages = [system_message, user_message]
+
+ query_rephrase_prompt_builder = PromptBuilder(query_rephrase_template)
+ query_rephrase_llm = GoogleAIGeminiGenerator(model=gemini_model)
+ list_to_str_adapter = OutputAdapter(template="{{ replies[0] }}", output_type=str)
+ wakachi_adapter = WakachiAdapter()
+
+ text_embedder = SentenceTransformersTextEmbedder(model=sentense_transformers_model,trust_remote_code=True)
+ embedding_retriever = InMemoryEmbeddingRetriever(document_store)
+ bm25_retriever = InMemoryBM25Retriever(document_store)
+ document_joiner = DocumentJoiner()
+ ranker = TransformersSimilarityRanker(model=ranker_model,meta_fields_to_embed=['title'],model_kwargs={'trust_remote_code':True})
+ prompt_builder = ChatPromptBuilder(template=messages,variables=["query", "documents", "memories"], required_variables=["query", "documents", "memories"])
+ gemini = GoogleAIGeminiChatGenerator(model=gemini_model)
+
+ memory_store = InMemoryChatMessageStore()
+ memory_joiner = ListJoiner(List[ChatMessage])
+ memory_retriever = ChatMessageRetriever(memory_store)
+ memory_writer = ChatMessageWriter(memory_store)
+
+ pipe = Pipeline()
+ pipe.add_component("query_rephrase_prompt_builder", query_rephrase_prompt_builder)
+ pipe.add_component("query_rephrase_llm", query_rephrase_llm)
+ pipe.add_component("list_to_str_adapter", list_to_str_adapter)
+ pipe.add_component("wakachi_adapter", wakachi_adapter)
+ pipe.add_component("text_embedder", text_embedder)
+ pipe.add_component("embedding_retriever", embedding_retriever)
+ pipe.add_component("bm25_retriever", bm25_retriever)
+ pipe.add_component("document_joiner", document_joiner)
+ pipe.add_component("ranker", ranker)
+ pipe.add_component("prompt_builder", prompt_builder)
+ pipe.add_component("llm", gemini)
+ pipe.add_component("memory_retriever", memory_retriever)
+ pipe.add_component("memory_joiner", memory_joiner)
+ pipe.add_component("memory_writer", memory_writer)
+
+ pipe.connect("memory_retriever", "query_rephrase_prompt_builder.memories")
+ pipe.connect("query_rephrase_prompt_builder.prompt", "query_rephrase_llm")
+ pipe.connect("query_rephrase_llm.replies", "list_to_str_adapter")
+ pipe.connect("list_to_str_adapter", "text_embedder.text")
+ pipe.connect("list_to_str_adapter", "wakachi_adapter.query")
+ pipe.connect("wakachi_adapter.wakachi", "bm25_retriever.query")
+ pipe.connect("text_embedder", "embedding_retriever")
+ pipe.connect("embedding_retriever", "document_joiner")
+ pipe.connect("bm25_retriever", "document_joiner")
+ pipe.connect("document_joiner", "ranker")
+ pipe.connect("list_to_str_adapter", "ranker.query")
+ pipe.connect("ranker.documents", "prompt_builder.documents")
+ pipe.connect("memory_retriever", "prompt_builder.memories")
+ pipe.connect("prompt_builder.prompt", "llm.messages")
+ pipe.connect("llm.replies", "memory_joiner")
+ pipe.connect("memory_joiner", "memory_writer")