niwa-rag / pipe.py
karubiniumu's picture
Update pipe.py
ebc1a2c verified
sentense_transformers_model = "Alibaba-NLP/gte-multilingual-base"
ranker_model = 'Alibaba-NLP/gte-multilingual-reranker-base'
gemini_model = 'models/gemini-1.5-flash'
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.joiners import DocumentJoiner
from haystack.components.rankers import TransformersSimilarityRanker
from haystack import Pipeline,component,Document
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator,GoogleAIGeminiGenerator
from haystack.components.builders import ChatPromptBuilder,PromptBuilder
from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore
from haystack_experimental.components.retrievers import ChatMessageRetriever
from haystack_experimental.components.writers import ChatMessageWriter
from haystack.dataclasses import ChatMessage
from itertools import chain
from typing import Any,List
from haystack.core.component.types import Variadic
from haystack.components.converters import OutputAdapter
from fugashi import Tagger
tagger = Tagger()
def gen_wakachi(text):
words = tagger(text)
return ' '.join([word.surface for word in words])
@component
class ListJoiner:
def __init__(self, _type: Any):
component.set_output_types(self, values=_type)
def run(self, values: Variadic[Any]):
result = list(chain(*values))
return {"values": result}
@component
class WakachiAdapter:
@component.output_types(wakachi=str)
def run(self, query:str):
return {'wakachi':gen_wakachi(query)}
document_store = InMemoryDocumentStore.load_from_disk(path='./document_store.json')
print('document_store loaded' ,document_store.count_documents())
user_message_template = """
チャット履歴と提供された資料に基づいて、質問に答えてください。
資料はチャット履歴の一部ではないことに注意してください。
質問が資料から回答できない場合は、その旨を述べてください。
チャット履歴:
{% for memory in memories %}
{{ memory.text }}
{% endfor %}
資料:
{% for document in documents %}
{{ document.content }}
{% endfor %}
質問: {{query}}
回答:
"""
query_rephrase_template = """
意味とキーワードをそのまま維持しながら、検索用の質問を書き直してください。
チャット履歴が空の場合は、クエリを変更しないでください。
チャット履歴は必要な場合にのみ使用し、独自の知識でクエリを拡張することは避けてください。
変更の必要がない場合は、現在の質問をそのまま出力して下さい。
チャット履歴 :
{% for memory in memories %}
{{ memory.text }}
{% endfor %}
ユーザーの質問 : {{query}}
書き換えられた質問 :
"""
system_message = ChatMessage.from_system("あなたは、提供された資料とチャット履歴を使用して人間を支援するAIアシスタントです")
user_message = ChatMessage.from_user(user_message_template)
messages = [system_message, user_message]
query_rephrase_prompt_builder = PromptBuilder(query_rephrase_template)
query_rephrase_llm = GoogleAIGeminiGenerator(model=gemini_model)
list_to_str_adapter = OutputAdapter(template="{{ replies[0] }}", output_type=str)
wakachi_adapter = WakachiAdapter()
text_embedder = SentenceTransformersTextEmbedder(model=sentense_transformers_model,trust_remote_code=True)
embedding_retriever = InMemoryEmbeddingRetriever(document_store)
bm25_retriever = InMemoryBM25Retriever(document_store)
document_joiner = DocumentJoiner()
ranker = TransformersSimilarityRanker(model=ranker_model,meta_fields_to_embed=['title'],model_kwargs={'trust_remote_code':True})
prompt_builder = ChatPromptBuilder(template=messages,variables=["query", "documents", "memories"], required_variables=["query", "documents", "memories"])
gemini = GoogleAIGeminiChatGenerator(model=gemini_model)
prompt_builder_emb_only = ChatPromptBuilder(template=messages,variables=["query", "documents", "memories"], required_variables=["query", "documents", "memories"])
gemini_emb_only = GoogleAIGeminiChatGenerator(model=gemini_model)
memory_store = InMemoryChatMessageStore()
memory_joiner = ListJoiner(List[ChatMessage])
memory_retriever = ChatMessageRetriever(memory_store)
memory_writer = ChatMessageWriter(memory_store)
pipe = Pipeline()
pipe.add_component("query_rephrase_prompt_builder", query_rephrase_prompt_builder)
pipe.add_component("query_rephrase_llm", query_rephrase_llm)
pipe.add_component("list_to_str_adapter", list_to_str_adapter)
pipe.add_component("wakachi_adapter", wakachi_adapter)
pipe.add_component("text_embedder", text_embedder)
pipe.add_component("embedding_retriever", embedding_retriever)
pipe.add_component("bm25_retriever", bm25_retriever)
pipe.add_component("document_joiner", document_joiner)
pipe.add_component("ranker", ranker)
pipe.add_component("prompt_builder", prompt_builder)
pipe.add_component("llm", gemini)
pipe.add_component("memory_retriever", memory_retriever)
pipe.add_component("memory_joiner", memory_joiner)
pipe.add_component("memory_writer", memory_writer)
pipe.add_component("prompt_builder_emb_only", prompt_builder_emb_only)
pipe.add_component("llm_emb_only", gemini_emb_only)
pipe.connect("memory_retriever", "query_rephrase_prompt_builder.memories")
pipe.connect("query_rephrase_prompt_builder.prompt", "query_rephrase_llm")
pipe.connect("query_rephrase_llm.replies", "list_to_str_adapter")
pipe.connect("list_to_str_adapter", "text_embedder.text")
pipe.connect("list_to_str_adapter", "wakachi_adapter.query")
pipe.connect("wakachi_adapter.wakachi", "bm25_retriever.query")
pipe.connect("text_embedder", "embedding_retriever")
pipe.connect("embedding_retriever", "document_joiner")
pipe.connect("bm25_retriever", "document_joiner")
pipe.connect("document_joiner", "ranker")
pipe.connect("list_to_str_adapter", "ranker.query")
pipe.connect("ranker.documents", "prompt_builder.documents")
pipe.connect("memory_retriever", "prompt_builder.memories")
pipe.connect("prompt_builder.prompt", "llm.messages")
pipe.connect("llm.replies", "memory_joiner")
pipe.connect("memory_joiner", "memory_writer")
pipe.connect("embedding_retriever.documents", "prompt_builder_emb_only.documents")
pipe.connect("memory_retriever", "prompt_builder_emb_only.memories")
pipe.connect("prompt_builder_emb_only.prompt", "llm_emb_only.messages")