Spaces:
Running
Running
sentense_transformers_model = "Alibaba-NLP/gte-multilingual-base" | |
ranker_model = 'Alibaba-NLP/gte-multilingual-reranker-base' | |
gemini_model = 'models/gemini-1.5-flash' | |
from haystack.document_stores.in_memory import InMemoryDocumentStore | |
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever | |
from haystack.components.embedders import SentenceTransformersTextEmbedder | |
from haystack.components.joiners import DocumentJoiner | |
from haystack.components.rankers import TransformersSimilarityRanker | |
from haystack import Pipeline,component,Document | |
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator,GoogleAIGeminiGenerator | |
from haystack.components.builders import ChatPromptBuilder,PromptBuilder | |
from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore | |
from haystack_experimental.components.retrievers import ChatMessageRetriever | |
from haystack_experimental.components.writers import ChatMessageWriter | |
from haystack.dataclasses import ChatMessage | |
from itertools import chain | |
from typing import Any,List | |
from haystack.core.component.types import Variadic | |
from haystack.components.converters import OutputAdapter | |
from fugashi import Tagger | |
tagger = Tagger() | |
def gen_wakachi(text): | |
words = tagger(text) | |
return ' '.join([word.surface for word in words]) | |
class ListJoiner: | |
def __init__(self, _type: Any): | |
component.set_output_types(self, values=_type) | |
def run(self, values: Variadic[Any]): | |
result = list(chain(*values)) | |
return {"values": result} | |
class WakachiAdapter: | |
def run(self, query:str): | |
return {'wakachi':gen_wakachi(query)} | |
document_store = InMemoryDocumentStore.load_from_disk(path='./document_store.json') | |
print('document_store loaded' ,document_store.count_documents()) | |
user_message_template = """ | |
チャット履歴と提供された資料に基づいて、質問に答えてください。 | |
資料はチャット履歴の一部ではないことに注意してください。 | |
質問が資料から回答できない場合は、その旨を述べてください。 | |
チャット履歴: | |
{% for memory in memories %} | |
{{ memory.text }} | |
{% endfor %} | |
資料: | |
{% for document in documents %} | |
{{ document.content }} | |
{% endfor %} | |
質問: {{query}} | |
回答: | |
""" | |
query_rephrase_template = """ | |
意味とキーワードをそのまま維持しながら、検索用の質問を書き直してください。 | |
チャット履歴が空の場合は、クエリを変更しないでください。 | |
チャット履歴は必要な場合にのみ使用し、独自の知識でクエリを拡張することは避けてください。 | |
変更の必要がない場合は、現在の質問をそのまま出力して下さい。 | |
チャット履歴 : | |
{% for memory in memories %} | |
{{ memory.text }} | |
{% endfor %} | |
ユーザーの質問 : {{query}} | |
書き換えられた質問 : | |
""" | |
system_message = ChatMessage.from_system("あなたは、提供された資料とチャット履歴を使用して人間を支援するAIアシスタントです") | |
user_message = ChatMessage.from_user(user_message_template) | |
messages = [system_message, user_message] | |
query_rephrase_prompt_builder = PromptBuilder(query_rephrase_template) | |
query_rephrase_llm = GoogleAIGeminiGenerator(model=gemini_model) | |
list_to_str_adapter = OutputAdapter(template="{{ replies[0] }}", output_type=str) | |
wakachi_adapter = WakachiAdapter() | |
text_embedder = SentenceTransformersTextEmbedder(model=sentense_transformers_model,trust_remote_code=True) | |
embedding_retriever = InMemoryEmbeddingRetriever(document_store) | |
bm25_retriever = InMemoryBM25Retriever(document_store) | |
document_joiner = DocumentJoiner() | |
ranker = TransformersSimilarityRanker(model=ranker_model,meta_fields_to_embed=['title'],model_kwargs={'trust_remote_code':True}) | |
prompt_builder = ChatPromptBuilder(template=messages,variables=["query", "documents", "memories"], required_variables=["query", "documents", "memories"]) | |
gemini = GoogleAIGeminiChatGenerator(model=gemini_model) | |
prompt_builder_emb_only = ChatPromptBuilder(template=messages,variables=["query", "documents", "memories"], required_variables=["query", "documents", "memories"]) | |
gemini_emb_only = GoogleAIGeminiChatGenerator(model=gemini_model) | |
memory_store = InMemoryChatMessageStore() | |
memory_joiner = ListJoiner(List[ChatMessage]) | |
memory_retriever = ChatMessageRetriever(memory_store) | |
memory_writer = ChatMessageWriter(memory_store) | |
pipe = Pipeline() | |
pipe.add_component("query_rephrase_prompt_builder", query_rephrase_prompt_builder) | |
pipe.add_component("query_rephrase_llm", query_rephrase_llm) | |
pipe.add_component("list_to_str_adapter", list_to_str_adapter) | |
pipe.add_component("wakachi_adapter", wakachi_adapter) | |
pipe.add_component("text_embedder", text_embedder) | |
pipe.add_component("embedding_retriever", embedding_retriever) | |
pipe.add_component("bm25_retriever", bm25_retriever) | |
pipe.add_component("document_joiner", document_joiner) | |
pipe.add_component("ranker", ranker) | |
pipe.add_component("prompt_builder", prompt_builder) | |
pipe.add_component("llm", gemini) | |
pipe.add_component("memory_retriever", memory_retriever) | |
pipe.add_component("memory_joiner", memory_joiner) | |
pipe.add_component("memory_writer", memory_writer) | |
pipe.add_component("prompt_builder_emb_only", prompt_builder_emb_only) | |
pipe.add_component("llm_emb_only", gemini_emb_only) | |
pipe.connect("memory_retriever", "query_rephrase_prompt_builder.memories") | |
pipe.connect("query_rephrase_prompt_builder.prompt", "query_rephrase_llm") | |
pipe.connect("query_rephrase_llm.replies", "list_to_str_adapter") | |
pipe.connect("list_to_str_adapter", "text_embedder.text") | |
pipe.connect("list_to_str_adapter", "wakachi_adapter.query") | |
pipe.connect("wakachi_adapter.wakachi", "bm25_retriever.query") | |
pipe.connect("text_embedder", "embedding_retriever") | |
pipe.connect("embedding_retriever", "document_joiner") | |
pipe.connect("bm25_retriever", "document_joiner") | |
pipe.connect("document_joiner", "ranker") | |
pipe.connect("list_to_str_adapter", "ranker.query") | |
pipe.connect("ranker.documents", "prompt_builder.documents") | |
pipe.connect("memory_retriever", "prompt_builder.memories") | |
pipe.connect("prompt_builder.prompt", "llm.messages") | |
pipe.connect("llm.replies", "memory_joiner") | |
pipe.connect("memory_joiner", "memory_writer") | |
pipe.connect("embedding_retriever.documents", "prompt_builder_emb_only.documents") | |
pipe.connect("memory_retriever", "prompt_builder_emb_only.memories") | |
pipe.connect("prompt_builder_emb_only.prompt", "llm_emb_only.messages") |