"""Gradio chat front-end for a hybrid (BM25 + embedding) RAG pipeline built with Haystack.

Documents are loaded from ``./document_store.json``. Each question is answered by a
Gemini chat generator from the reranked documents plus the stored conversation
history, and the answer is shown together with an HTML list of the source articles.
"""
from html import escape
from itertools import chain
from typing import Any, List

import gradio as gr
from haystack import Document, Pipeline, component
from haystack.components.builders import ChatPromptBuilder
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.joiners import DocumentJoiner
from haystack.components.rankers import TransformersSimilarityRanker
from haystack.components.retrievers.in_memory import (
    InMemoryBM25Retriever,
    InMemoryEmbeddingRetriever,
)
from haystack.core.component.types import Variadic
from haystack.dataclasses import ChatMessage
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore
from haystack_experimental.components.retrievers import ChatMessageRetriever
from haystack_experimental.components.writers import ChatMessageWriter
from haystack_integrations.components.generators.google_ai import (
    GoogleAIGeminiChatGenerator,
)

# Model used to embed the query for InMemoryEmbeddingRetriever.
# Presumably the same model that embedded the stored documents — TODO confirm.
sentense_transformers_model = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
# Cross-encoder used to rerank the merged BM25 + embedding hits.
ranker_model = 'hotchpotch/japanese-reranker-cross-encoder-base-v1'

document_store = InMemoryDocumentStore.load_from_disk(path='./document_store.json')
print('document_store loaded', document_store.count_documents())


@component
class ListJoiner:
    """Haystack component that flattens several list inputs into one list.

    In this file it merges the user's question message with the LLM replies so
    both can be persisted by ``ChatMessageWriter``.
    """

    def __init__(self, _type: Any):
        # The output socket type is declared dynamically (here List[ChatMessage]).
        component.set_output_types(self, values=_type)

    def run(self, values: Variadic[Any]):
        """Concatenate every incoming list, in the order they are received."""
        result = list(chain(*values))
        return {"values": result}


class Niwa_rag:
    """Conversational RAG over the module-level ``document_store``.

    Retrieval runs BM25 and embedding search, joins the hits, reranks them with a
    cross-encoder (top 8) and passes them, together with the stored chat history,
    to Gemini. The question and the generated replies are then written back to an
    in-memory chat message store owned by this instance.
    """

    def __init__(self):
        self.createPipe()

    def createPipe(self):
        """Build the Haystack pipeline and store it on ``self.pipe``."""
        # Prompt: answer from the conversation history and the supplied documents.
        user_message_template = """
        会話の履歴と提供された資料に基づいて、質問に答えてください。

        会話の履歴:
        {% for memory in memories %}
            {{ memory.content }}
        {% endfor %}

        資料:
        {% for document in documents %}
            {{ document.content }}
        {% endfor %}

        質問: {{query}}
        回答:
        """
        system_message = ChatMessage.from_system(
            "あなたは、提供された資料と会話履歴を使用して人間を支援するAIアシスタントです"
        )
        user_message = ChatMessage.from_user(user_message_template)
        messages = [system_message, user_message]

        # Retrieval / ranking components.
        text_embedder = SentenceTransformersTextEmbedder(model=sentense_transformers_model)
        embedding_retriever = InMemoryEmbeddingRetriever(document_store)
        bm25_retriever = InMemoryBM25Retriever(document_store)
        document_joiner = DocumentJoiner()
        ranker = TransformersSimilarityRanker(model=ranker_model, top_k=8)

        # Generation components.
        prompt_builder = ChatPromptBuilder(
            template=messages,
            variables=["query", "documents", "memories"],
            required_variables=["query", "documents", "memories"],
        )
        # NOTE(review): Google appears to have deprecated the Gemini 1.0 models;
        # confirm "models/gemini-1.0-pro" is still served.
        gemini = GoogleAIGeminiChatGenerator(model="models/gemini-1.0-pro")

        # Conversation-memory components (one store per Niwa_rag instance).
        memory_store = InMemoryChatMessageStore()
        memory_joiner = ListJoiner(List[ChatMessage])
        memory_retriever = ChatMessageRetriever(memory_store)
        memory_writer = ChatMessageWriter(memory_store)

        pipe = Pipeline()
        pipe.add_component("text_embedder", text_embedder)
        pipe.add_component("embedding_retriever", embedding_retriever)
        pipe.add_component("bm25_retriever", bm25_retriever)
        pipe.add_component("document_joiner", document_joiner)
        pipe.add_component("ranker", ranker)
        pipe.add_component("prompt_builder", prompt_builder)
        pipe.add_component("llm", gemini)
        pipe.add_component("memory_retriever", memory_retriever)
        pipe.add_component("memory_writer", memory_writer)
        pipe.add_component("memory_joiner", memory_joiner)

        # Hybrid retrieval -> join -> rerank -> prompt -> LLM.
        pipe.connect("text_embedder", "embedding_retriever")
        pipe.connect("bm25_retriever", "document_joiner")
        pipe.connect("embedding_retriever", "document_joiner")
        pipe.connect("document_joiner", "ranker")
        pipe.connect("ranker.documents", "prompt_builder.documents")
        pipe.connect("prompt_builder.prompt", "llm.messages")
        # LLM replies (plus the user message supplied at run time) are persisted.
        pipe.connect("llm.replies", "memory_joiner")
        pipe.connect("memory_joiner", "memory_writer")
        # Stored history feeds the prompt.
        pipe.connect("memory_retriever", "prompt_builder.memories")
        self.pipe = pipe

    def run(self, q):
        """Answer question ``q``.

        Returns:
            dict with ``'reply'`` (the first LLM reply's content) and ``'sources'``
            (an HTML fragment listing the reranked documents). Both are empty
            strings when ``q`` is empty, ``None`` or whitespace-only.
        """
        print('q:', q)
        # Don't spend retrieval, reranking and an LLM call on blank input.
        if not q or not q.strip():
            return {'reply': '', 'sources': ''}
        result = self.pipe.run(
            {
                "text_embedder": {"text": q},
                "bm25_retriever": {"query": q},
                "ranker": {"query": q},
                "prompt_builder": {"query": q},
                # The user's turn is joined with the LLM replies and written to memory.
                "memory_joiner": {"values": [ChatMessage.from_user(q)]},
            },
            # The ranker's output is consumed by prompt_builder, so it must be
            # requested explicitly to appear in the result.
            include_outputs_from=["llm", 'ranker'],
        )
        reply = result['llm']['replies'][0]
        docs = result['ranker']['documents']
        print('reply:', reply)
        return {'reply': reply.content, 'sources': self._sources_html(docs)}

    @staticmethod
    def _sources_html(docs):
        """Render ``docs`` as an HTML block headed "参考記事" with one title link per doc.

        ``title`` and ``link`` are read from ``doc.meta`` (empty when missing) and
        HTML-escaped so characters such as ``&``, ``<`` or quotes cannot break the markup.
        """
        parts = ['<div class="ref"><div class="ref-title">参考記事</div>']
        for doc in docs:
            title = doc.meta.get('title') or ''
            link = doc.meta.get('link') or ''
            parts.append(
                f'<div><a class="link" href="{escape(str(link))}" target="_blank" '
                f'rel="noopener noreferrer">{escape(str(title))}</a></div>'
            )
            print('', title, link, doc.meta.get('type'), doc.score)
        parts.append('</div>')
        return ''.join(parts)


rag = Niwa_rag()


def fn(q, history):
    """Gradio ChatInterface callback: return the answer followed by the source list HTML.

    ``history`` is supplied by Gradio but unused; conversation memory comes from
    the pipeline's own chat message store.
    """
    # NOTE(review): ``rag`` is a single module-level instance, so its chat memory
    # is shared by every browser session (including public share=True visitors).
    # Consider per-session memory if conversations must be isolated.
    result = rag.run(q)
    return result['reply'] + result['sources']


app = gr.ChatInterface(
    fn,
    type="messages",
    textbox=gr.Textbox(placeholder='質問を記入して下さい', submit_btn=True),
    # `!important` belongs inside the declaration, before the `;`.
    css=(
        '.ref-title{margin-top:10px} '
        '.ref .link{margin-left:20px;font-size:90%;color:-webkit-link !important}'
    ),
)

if __name__ == "__main__":
    app.launch(share=True)