Spaces:
Running
Running
File size: 6,701 Bytes
e2a8726 ebc1a2c e2a8726 804908b e2a8726 804908b e2a8726 804908b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
sentense_transformers_model = "Alibaba-NLP/gte-multilingual-base"
ranker_model = 'Alibaba-NLP/gte-multilingual-reranker-base'
gemini_model = 'models/gemini-1.5-flash'
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.joiners import DocumentJoiner
from haystack.components.rankers import TransformersSimilarityRanker
from haystack import Pipeline,component,Document
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator,GoogleAIGeminiGenerator
from haystack.components.builders import ChatPromptBuilder,PromptBuilder
from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore
from haystack_experimental.components.retrievers import ChatMessageRetriever
from haystack_experimental.components.writers import ChatMessageWriter
from haystack.dataclasses import ChatMessage
from itertools import chain
from typing import Any,List
from haystack.core.component.types import Variadic
from haystack.components.converters import OutputAdapter
from fugashi import Tagger
tagger = Tagger()
def gen_wakachi(text):
words = tagger(text)
return ' '.join([word.surface for word in words])
@component
class ListJoiner:
def __init__(self, _type: Any):
component.set_output_types(self, values=_type)
def run(self, values: Variadic[Any]):
result = list(chain(*values))
return {"values": result}
@component
class WakachiAdapter:
@component.output_types(wakachi=str)
def run(self, query:str):
return {'wakachi':gen_wakachi(query)}
document_store = InMemoryDocumentStore.load_from_disk(path='./document_store.json')
print('document_store loaded' ,document_store.count_documents())
user_message_template = """
チャット履歴と提供された資料に基づいて、質問に答えてください。
資料はチャット履歴の一部ではないことに注意してください。
質問が資料から回答できない場合は、その旨を述べてください。
チャット履歴:
{% for memory in memories %}
{{ memory.text }}
{% endfor %}
資料:
{% for document in documents %}
{{ document.content }}
{% endfor %}
質問: {{query}}
回答:
"""
query_rephrase_template = """
意味とキーワードをそのまま維持しながら、検索用の質問を書き直してください。
チャット履歴が空の場合は、クエリを変更しないでください。
チャット履歴は必要な場合にのみ使用し、独自の知識でクエリを拡張することは避けてください。
変更の必要がない場合は、現在の質問をそのまま出力して下さい。
チャット履歴 :
{% for memory in memories %}
{{ memory.text }}
{% endfor %}
ユーザーの質問 : {{query}}
書き換えられた質問 :
"""
system_message = ChatMessage.from_system("あなたは、提供された資料とチャット履歴を使用して人間を支援するAIアシスタントです")
user_message = ChatMessage.from_user(user_message_template)
messages = [system_message, user_message]
query_rephrase_prompt_builder = PromptBuilder(query_rephrase_template)
query_rephrase_llm = GoogleAIGeminiGenerator(model=gemini_model)
list_to_str_adapter = OutputAdapter(template="{{ replies[0] }}", output_type=str)
wakachi_adapter = WakachiAdapter()
text_embedder = SentenceTransformersTextEmbedder(model=sentense_transformers_model,trust_remote_code=True)
embedding_retriever = InMemoryEmbeddingRetriever(document_store)
bm25_retriever = InMemoryBM25Retriever(document_store)
document_joiner = DocumentJoiner()
ranker = TransformersSimilarityRanker(model=ranker_model,meta_fields_to_embed=['title'],model_kwargs={'trust_remote_code':True})
prompt_builder = ChatPromptBuilder(template=messages,variables=["query", "documents", "memories"], required_variables=["query", "documents", "memories"])
gemini = GoogleAIGeminiChatGenerator(model=gemini_model)
prompt_builder_emb_only = ChatPromptBuilder(template=messages,variables=["query", "documents", "memories"], required_variables=["query", "documents", "memories"])
gemini_emb_only = GoogleAIGeminiChatGenerator(model=gemini_model)
memory_store = InMemoryChatMessageStore()
memory_joiner = ListJoiner(List[ChatMessage])
memory_retriever = ChatMessageRetriever(memory_store)
memory_writer = ChatMessageWriter(memory_store)
pipe = Pipeline()
pipe.add_component("query_rephrase_prompt_builder", query_rephrase_prompt_builder)
pipe.add_component("query_rephrase_llm", query_rephrase_llm)
pipe.add_component("list_to_str_adapter", list_to_str_adapter)
pipe.add_component("wakachi_adapter", wakachi_adapter)
pipe.add_component("text_embedder", text_embedder)
pipe.add_component("embedding_retriever", embedding_retriever)
pipe.add_component("bm25_retriever", bm25_retriever)
pipe.add_component("document_joiner", document_joiner)
pipe.add_component("ranker", ranker)
pipe.add_component("prompt_builder", prompt_builder)
pipe.add_component("llm", gemini)
pipe.add_component("memory_retriever", memory_retriever)
pipe.add_component("memory_joiner", memory_joiner)
pipe.add_component("memory_writer", memory_writer)
pipe.add_component("prompt_builder_emb_only", prompt_builder_emb_only)
pipe.add_component("llm_emb_only", gemini_emb_only)
pipe.connect("memory_retriever", "query_rephrase_prompt_builder.memories")
pipe.connect("query_rephrase_prompt_builder.prompt", "query_rephrase_llm")
pipe.connect("query_rephrase_llm.replies", "list_to_str_adapter")
pipe.connect("list_to_str_adapter", "text_embedder.text")
pipe.connect("list_to_str_adapter", "wakachi_adapter.query")
pipe.connect("wakachi_adapter.wakachi", "bm25_retriever.query")
pipe.connect("text_embedder", "embedding_retriever")
pipe.connect("embedding_retriever", "document_joiner")
pipe.connect("bm25_retriever", "document_joiner")
pipe.connect("document_joiner", "ranker")
pipe.connect("list_to_str_adapter", "ranker.query")
pipe.connect("ranker.documents", "prompt_builder.documents")
pipe.connect("memory_retriever", "prompt_builder.memories")
pipe.connect("prompt_builder.prompt", "llm.messages")
pipe.connect("llm.replies", "memory_joiner")
pipe.connect("memory_joiner", "memory_writer")
pipe.connect("embedding_retriever.documents", "prompt_builder_emb_only.documents")
pipe.connect("memory_retriever", "prompt_builder_emb_only.memories")
pipe.connect("prompt_builder_emb_only.prompt", "llm_emb_only.messages") |