Spaces:
Running
Running
karubiniumu
commited on
Commit
•
57957da
1
Parent(s):
b335705
chat
Browse files
app.py
CHANGED
@@ -9,42 +9,63 @@ from haystack.components.embedders import SentenceTransformersTextEmbedder
|
|
9 |
from haystack.components.joiners import DocumentJoiner
|
10 |
from haystack.components.rankers import TransformersSimilarityRanker
|
11 |
from haystack import Pipeline,component,Document
|
12 |
-
from haystack_integrations.components.generators.google_ai import
|
13 |
-
from haystack.components.builders import
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
|
22 |
document_store = InMemoryDocumentStore.load_from_disk(path='./document_store.json')
|
23 |
print('document_store loaded' ,document_store.count_documents())
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
class Niwa_rag :
|
26 |
def __init__(self):
|
27 |
self.createPipe()
|
28 |
def createPipe(self):
|
29 |
-
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
-
|
33 |
{% for document in documents %}
|
34 |
{{ document.content }}
|
35 |
{% endfor %}
|
36 |
|
37 |
-
質問: {{
|
38 |
回答:
|
39 |
"""
|
|
|
|
|
|
|
|
|
40 |
text_embedder = SentenceTransformersTextEmbedder(model=sentense_transformers_model)
|
41 |
embedding_retriever = InMemoryEmbeddingRetriever(document_store)
|
42 |
bm25_retriever = InMemoryBM25Retriever(document_store)
|
43 |
document_joiner = DocumentJoiner()
|
44 |
-
ranker = TransformersSimilarityRanker(model=ranker_model,top_k=
|
45 |
-
prompt_builder =
|
46 |
-
gemini =
|
47 |
-
|
|
|
|
|
|
|
|
|
48 |
|
49 |
pipe = Pipeline()
|
50 |
pipe.add_component("text_embedder", text_embedder)
|
@@ -54,43 +75,58 @@ class Niwa_rag :
|
|
54 |
pipe.add_component("ranker", ranker)
|
55 |
pipe.add_component("prompt_builder", prompt_builder)
|
56 |
pipe.add_component("llm", gemini)
|
57 |
-
pipe.add_component("
|
|
|
|
|
58 |
|
59 |
pipe.connect("text_embedder", "embedding_retriever")
|
60 |
pipe.connect("bm25_retriever", "document_joiner")
|
61 |
pipe.connect("embedding_retriever", "document_joiner")
|
62 |
pipe.connect("document_joiner", "ranker")
|
63 |
pipe.connect("ranker.documents", "prompt_builder.documents")
|
64 |
-
pipe.connect("prompt_builder.prompt", "llm")
|
65 |
-
pipe.connect("llm.replies", "
|
66 |
-
pipe.connect("
|
|
|
|
|
67 |
self.pipe = pipe
|
68 |
def run(self,q):
|
|
|
69 |
if not q :
|
70 |
-
return
|
71 |
-
result = self.pipe.run(
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
title = doc.meta['title']
|
78 |
link = doc.meta['link']
|
79 |
-
row = f'<div><a href="{link}" target="_blank">{title}</a></div>'
|
80 |
html += row
|
|
|
81 |
html += '</div>'
|
82 |
-
return
|
83 |
|
84 |
rag = Niwa_rag()
|
85 |
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
|
|
|
|
|
|
94 |
|
95 |
if __name__ == "__main__":
|
96 |
app.launch(share=True)
|
|
|
9 |
from haystack.components.joiners import DocumentJoiner
|
10 |
from haystack.components.rankers import TransformersSimilarityRanker
|
11 |
from haystack import Pipeline,component,Document
|
12 |
+
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator
|
13 |
+
from haystack.components.builders import ChatPromptBuilder
|
14 |
+
from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore
|
15 |
+
from haystack_experimental.components.retrievers import ChatMessageRetriever
|
16 |
+
from haystack_experimental.components.writers import ChatMessageWriter
|
17 |
+
from haystack.dataclasses import ChatMessage
|
18 |
+
from itertools import chain
|
19 |
+
from typing import Any,List
|
20 |
+
from haystack.core.component.types import Variadic
|
21 |
|
22 |
document_store = InMemoryDocumentStore.load_from_disk(path='./document_store.json')
|
23 |
print('document_store loaded' ,document_store.count_documents())
|
24 |
|
25 |
+
@component
|
26 |
+
class ListJoiner:
|
27 |
+
def __init__(self, _type: Any):
|
28 |
+
component.set_output_types(self, values=_type)
|
29 |
+
def run(self, values: Variadic[Any]):
|
30 |
+
result = list(chain(*values))
|
31 |
+
return {"values": result}
|
32 |
+
|
33 |
class Niwa_rag :
|
34 |
def __init__(self):
|
35 |
self.createPipe()
|
36 |
def createPipe(self):
|
37 |
+
user_message_template = """
|
38 |
+
会話の履歴と提供された資料に基づいて、質問に答えてください。
|
39 |
+
|
40 |
+
会話の履歴:
|
41 |
+
{% for memory in memories %}
|
42 |
+
{{ memory.content }}
|
43 |
+
{% endfor %}
|
44 |
|
45 |
+
資料:
|
46 |
{% for document in documents %}
|
47 |
{{ document.content }}
|
48 |
{% endfor %}
|
49 |
|
50 |
+
質問: {{query}}
|
51 |
回答:
|
52 |
"""
|
53 |
+
system_message = ChatMessage.from_system("あなたは、提供された資料と会話履歴を使用して人間を支援するAIアシスタントです")
|
54 |
+
user_message = ChatMessage.from_user(user_message_template)
|
55 |
+
messages = [system_message, user_message]
|
56 |
+
|
57 |
text_embedder = SentenceTransformersTextEmbedder(model=sentense_transformers_model)
|
58 |
embedding_retriever = InMemoryEmbeddingRetriever(document_store)
|
59 |
bm25_retriever = InMemoryBM25Retriever(document_store)
|
60 |
document_joiner = DocumentJoiner()
|
61 |
+
ranker = TransformersSimilarityRanker(model=ranker_model,top_k=8)
|
62 |
+
prompt_builder = ChatPromptBuilder(template=messages,variables=["query", "documents", "memories"], required_variables=["query", "documents", "memories"])
|
63 |
+
gemini = GoogleAIGeminiChatGenerator(model="models/gemini-1.0-pro")
|
64 |
+
|
65 |
+
memory_store = InMemoryChatMessageStore()
|
66 |
+
memory_joiner = ListJoiner(List[ChatMessage])
|
67 |
+
memory_retriever = ChatMessageRetriever(memory_store)
|
68 |
+
memory_writer = ChatMessageWriter(memory_store)
|
69 |
|
70 |
pipe = Pipeline()
|
71 |
pipe.add_component("text_embedder", text_embedder)
|
|
|
75 |
pipe.add_component("ranker", ranker)
|
76 |
pipe.add_component("prompt_builder", prompt_builder)
|
77 |
pipe.add_component("llm", gemini)
|
78 |
+
pipe.add_component("memory_retriever", memory_retriever)
|
79 |
+
pipe.add_component("memory_writer", memory_writer)
|
80 |
+
pipe.add_component("memory_joiner", memory_joiner)
|
81 |
|
82 |
pipe.connect("text_embedder", "embedding_retriever")
|
83 |
pipe.connect("bm25_retriever", "document_joiner")
|
84 |
pipe.connect("embedding_retriever", "document_joiner")
|
85 |
pipe.connect("document_joiner", "ranker")
|
86 |
pipe.connect("ranker.documents", "prompt_builder.documents")
|
87 |
+
pipe.connect("prompt_builder.prompt", "llm.messages")
|
88 |
+
pipe.connect("llm.replies", "memory_joiner")
|
89 |
+
pipe.connect("memory_joiner", "memory_writer")
|
90 |
+
pipe.connect("memory_retriever", "prompt_builder.memories")
|
91 |
+
|
92 |
self.pipe = pipe
|
93 |
def run(self,q):
|
94 |
+
print('q:',q)
|
95 |
if not q :
|
96 |
+
return {'reply':'','sources':''}
|
97 |
+
result = self.pipe.run({
|
98 |
+
"text_embedder": {"text": q},
|
99 |
+
"bm25_retriever": {"query": q},
|
100 |
+
"ranker": {"query": q},
|
101 |
+
"prompt_builder": { "query": q},
|
102 |
+
"memory_joiner": {"values": [ChatMessage.from_user(q)]},
|
103 |
+
},include_outputs_from=["llm",'ranker'])
|
104 |
+
reply = result['llm']['replies'][0]
|
105 |
+
docs = result['ranker']['documents']
|
106 |
+
print('reply:',reply)
|
107 |
+
html = '<div class="ref-title">参考記事</div><div class="ref">'
|
108 |
+
for doc in docs :
|
109 |
title = doc.meta['title']
|
110 |
link = doc.meta['link']
|
111 |
+
row = f'<div><a class="link" href="{link}" target="_blank">{title}</a></div>'
|
112 |
html += row
|
113 |
+
print('',title,link,doc.score)
|
114 |
html += '</div>'
|
115 |
+
return {'reply':reply.content,'sources':html}
|
116 |
|
117 |
rag = Niwa_rag()
|
118 |
|
119 |
+
def fn(q,history):
|
120 |
+
result = rag.run(q)
|
121 |
+
return result['reply'] + result['sources']
|
122 |
+
|
123 |
+
app = gr.ChatInterface(
|
124 |
+
fn,
|
125 |
+
type="messages",
|
126 |
+
textbox=gr.Textbox(placeholder='質問を記入して下さい'),
|
127 |
+
css='.ref-title{margin-top:10px} .ref .link{margin-left:20px;font-size:90%;color:-webkit-link; !important}'
|
128 |
+
)
|
129 |
+
|
130 |
|
131 |
if __name__ == "__main__":
|
132 |
app.launch(share=True)
|