File size: 6,701 Bytes
e2a8726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebc1a2c
e2a8726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
804908b
 
e2a8726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
804908b
 
e2a8726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
804908b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
sentense_transformers_model = "Alibaba-NLP/gte-multilingual-base"
ranker_model = 'Alibaba-NLP/gte-multilingual-reranker-base'
gemini_model = 'models/gemini-1.5-flash'


from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.joiners import DocumentJoiner
from haystack.components.rankers import TransformersSimilarityRanker
from haystack import Pipeline,component,Document
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator,GoogleAIGeminiGenerator
from haystack.components.builders import ChatPromptBuilder,PromptBuilder
from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore
from haystack_experimental.components.retrievers import ChatMessageRetriever
from haystack_experimental.components.writers import ChatMessageWriter
from haystack.dataclasses import ChatMessage
from itertools import chain
from typing import Any,List
from haystack.core.component.types import Variadic
from haystack.components.converters import OutputAdapter
from fugashi import Tagger

tagger = Tagger()
def gen_wakachi(text):
    words = tagger(text)
    return ' '.join([word.surface for word in words])

@component
class ListJoiner:
    def __init__(self, _type: Any):
        component.set_output_types(self, values=_type)
    def run(self, values: Variadic[Any]):
        result = list(chain(*values))
        return {"values": result}

@component 
class WakachiAdapter:
    @component.output_types(wakachi=str)
    def run(self, query:str):
        return {'wakachi':gen_wakachi(query)}

document_store = InMemoryDocumentStore.load_from_disk(path='./document_store.json')
print('document_store loaded' ,document_store.count_documents())

user_message_template = """
    チャット履歴と提供された資料に基づいて、質問に答えてください。
    資料はチャット履歴の一部ではないことに注意してください。
    質問が資料から回答できない場合は、その旨を述べてください。
    
    チャット履歴:
    {% for memory in memories %}
        {{ memory.text }}
    {% endfor %}

    資料:
    {% for document in documents %}
        {{ document.content }}
    {% endfor %}

    質問: {{query}}
    回答:
"""
query_rephrase_template = """
    意味とキーワードをそのまま維持しながら、検索用の質問を書き直してください。
    チャット履歴が空の場合は、クエリを変更しないでください。
    チャット履歴は必要な場合にのみ使用し、独自の知識でクエリを拡張することは避けてください。
    変更の必要がない場合は、現在の質問をそのまま出力して下さい。

    チャット履歴 :
    {% for memory in memories %}
        {{ memory.text }}
    {% endfor %}

    ユーザーの質問 : {{query}}
    書き換えられた質問 :
"""

system_message = ChatMessage.from_system("あなたは、提供された資料とチャット履歴を使用して人間を支援するAIアシスタントです")
user_message = ChatMessage.from_user(user_message_template)
messages = [system_message, user_message]

query_rephrase_prompt_builder = PromptBuilder(query_rephrase_template)
query_rephrase_llm = GoogleAIGeminiGenerator(model=gemini_model)
list_to_str_adapter = OutputAdapter(template="{{ replies[0] }}", output_type=str)
wakachi_adapter = WakachiAdapter()

text_embedder = SentenceTransformersTextEmbedder(model=sentense_transformers_model,trust_remote_code=True)
embedding_retriever = InMemoryEmbeddingRetriever(document_store)
bm25_retriever = InMemoryBM25Retriever(document_store)
document_joiner = DocumentJoiner()
ranker = TransformersSimilarityRanker(model=ranker_model,meta_fields_to_embed=['title'],model_kwargs={'trust_remote_code':True})
prompt_builder = ChatPromptBuilder(template=messages,variables=["query", "documents", "memories"], required_variables=["query", "documents", "memories"])
gemini = GoogleAIGeminiChatGenerator(model=gemini_model)
prompt_builder_emb_only = ChatPromptBuilder(template=messages,variables=["query", "documents", "memories"], required_variables=["query", "documents", "memories"])
gemini_emb_only = GoogleAIGeminiChatGenerator(model=gemini_model)

memory_store = InMemoryChatMessageStore()
memory_joiner = ListJoiner(List[ChatMessage])
memory_retriever = ChatMessageRetriever(memory_store)
memory_writer = ChatMessageWriter(memory_store)

pipe = Pipeline()
pipe.add_component("query_rephrase_prompt_builder", query_rephrase_prompt_builder)
pipe.add_component("query_rephrase_llm", query_rephrase_llm)
pipe.add_component("list_to_str_adapter", list_to_str_adapter)
pipe.add_component("wakachi_adapter", wakachi_adapter)
pipe.add_component("text_embedder", text_embedder)
pipe.add_component("embedding_retriever", embedding_retriever)
pipe.add_component("bm25_retriever", bm25_retriever)
pipe.add_component("document_joiner", document_joiner)
pipe.add_component("ranker", ranker)
pipe.add_component("prompt_builder", prompt_builder)
pipe.add_component("llm", gemini)
pipe.add_component("memory_retriever", memory_retriever)
pipe.add_component("memory_joiner", memory_joiner)
pipe.add_component("memory_writer", memory_writer)
pipe.add_component("prompt_builder_emb_only", prompt_builder_emb_only)
pipe.add_component("llm_emb_only", gemini_emb_only)

pipe.connect("memory_retriever", "query_rephrase_prompt_builder.memories")
pipe.connect("query_rephrase_prompt_builder.prompt", "query_rephrase_llm")
pipe.connect("query_rephrase_llm.replies", "list_to_str_adapter")
pipe.connect("list_to_str_adapter", "text_embedder.text")
pipe.connect("list_to_str_adapter", "wakachi_adapter.query")
pipe.connect("wakachi_adapter.wakachi", "bm25_retriever.query")
pipe.connect("text_embedder", "embedding_retriever")
pipe.connect("embedding_retriever", "document_joiner")
pipe.connect("bm25_retriever", "document_joiner")
pipe.connect("document_joiner", "ranker")
pipe.connect("list_to_str_adapter", "ranker.query")
pipe.connect("ranker.documents", "prompt_builder.documents")
pipe.connect("memory_retriever", "prompt_builder.memories")
pipe.connect("prompt_builder.prompt", "llm.messages")
pipe.connect("llm.replies", "memory_joiner")
pipe.connect("memory_joiner", "memory_writer")
pipe.connect("embedding_retriever.documents", "prompt_builder_emb_only.documents")
pipe.connect("memory_retriever", "prompt_builder_emb_only.memories")
pipe.connect("prompt_builder_emb_only.prompt", "llm_emb_only.messages")