# Hugging Face Spaces status: Running
import gradio as gr | |
# Hugging Face model IDs: the embedding model used by
# SentenceTransformersTextEmbedder and the reranker model passed to
# TransformersSimilarityRanker when the pipeline is built.
sentense_transformers_model = (
    "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
ranker_model = "hotchpotch/japanese-reranker-cross-encoder-base-v1"
from haystack.document_stores.in_memory import InMemoryDocumentStore | |
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever | |
from haystack.components.embedders import SentenceTransformersTextEmbedder | |
from haystack.components.joiners import DocumentJoiner | |
from haystack.components.rankers import TransformersSimilarityRanker | |
from haystack import Pipeline,component,Document | |
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator | |
from haystack.components.builders import ChatPromptBuilder | |
from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore | |
from haystack_experimental.components.retrievers import ChatMessageRetriever | |
from haystack_experimental.components.writers import ChatMessageWriter | |
from haystack.dataclasses import ChatMessage | |
from itertools import chain | |
from typing import Any,List | |
from haystack.core.component.types import Variadic | |
# Load the persisted in-memory document index queried by both the BM25 and
# embedding retrievers. The path is relative to the current working directory.
document_store = InMemoryDocumentStore.load_from_disk(
    path="./document_store.json",
)
print("document_store loaded", document_store.count_documents())
@component
class ListJoiner:
    """Pipeline component that flattens several list inputs into one list.

    In this app it merges the user's ChatMessage with the LLM replies so both
    can be written to chat memory in one step. The ``@component`` decorator is
    required: without it Haystack never attaches input/output sockets to the
    instance, so ``Pipeline.add_component`` cannot register it and the
    ``Variadic`` input is not recognised.
    """

    def __init__(self, _type: Any):
        # The output type varies per instance (e.g. List[ChatMessage]), so it is
        # declared at runtime rather than with @component.output_types.
        component.set_output_types(self, values=_type)

    def run(self, values: Variadic[Any]):
        """Concatenate every incoming list into a single list.

        Args:
            values: One iterable per connected or directly supplied input.

        Returns:
            ``{"values": [...]}`` holding the items of all inputs, in the order
            the inputs are received.
        """
        result = list(chain.from_iterable(values))
        return {"values": result}
class Niwa_rag:
    """Conversational RAG assistant over the module-level ``document_store``.

    Each question goes through a Haystack pipeline that:
    - retrieves documents with both BM25 and embedding retrieval;
    - joins the results and reranks them (top 8);
    - prompts a Gemini chat model with those documents and the stored chat
      history.

    The user's question and the LLM replies are written back to an in-memory
    chat message store owned by this instance.
    """

    def __init__(self):
        self.createPipe()

    def createPipe(self):
        """Build the retrieval/generation pipeline and store it on ``self.pipe``.

        Wiring:
            text_embedder -> embedding_retriever -+
            bm25_retriever -----------------------+-> document_joiner -> ranker
            ranker.documents      -> prompt_builder.documents
            memory_retriever      -> prompt_builder.memories
            prompt_builder.prompt -> llm.messages
            llm.replies           -> memory_joiner -> memory_writer
        """
        user_message_template = """
会話の履歴と提供された資料に基づいて、質問に答えてください。
会話の履歴:
{% for memory in memories %}
{{ memory.content }}
{% endfor %}
資料:
{% for document in documents %}
{{ document.content }}
{% endfor %}
質問: {{query}}
回答:
"""
        system_message = ChatMessage.from_system("あなたは、提供された資料と会話履歴を使用して人間を支援するAIアシスタントです")
        user_message = ChatMessage.from_user(user_message_template)
        messages = [system_message, user_message]

        text_embedder = SentenceTransformersTextEmbedder(model=sentense_transformers_model)
        embedding_retriever = InMemoryEmbeddingRetriever(document_store)
        bm25_retriever = InMemoryBM25Retriever(document_store)
        document_joiner = DocumentJoiner()
        ranker = TransformersSimilarityRanker(model=ranker_model, top_k=8)
        prompt_builder = ChatPromptBuilder(
            template=messages,
            variables=["query", "documents", "memories"],
            required_variables=["query", "documents", "memories"],
        )
        gemini = GoogleAIGeminiChatGenerator(model="models/gemini-1.0-pro")

        # NOTE(review): this store lives as long as the instance. If a single
        # instance serves every user (as the module-level ``rag`` does), all
        # chat sessions share one history. Confirm this is intended.
        memory_store = InMemoryChatMessageStore()
        memory_joiner = ListJoiner(List[ChatMessage])
        memory_retriever = ChatMessageRetriever(memory_store)
        memory_writer = ChatMessageWriter(memory_store)

        pipe = Pipeline()
        pipe.add_component("text_embedder", text_embedder)
        pipe.add_component("embedding_retriever", embedding_retriever)
        pipe.add_component("bm25_retriever", bm25_retriever)
        pipe.add_component("document_joiner", document_joiner)
        pipe.add_component("ranker", ranker)
        pipe.add_component("prompt_builder", prompt_builder)
        pipe.add_component("llm", gemini)
        pipe.add_component("memory_retriever", memory_retriever)
        pipe.add_component("memory_writer", memory_writer)
        pipe.add_component("memory_joiner", memory_joiner)

        pipe.connect("text_embedder", "embedding_retriever")
        pipe.connect("bm25_retriever", "document_joiner")
        pipe.connect("embedding_retriever", "document_joiner")
        pipe.connect("document_joiner", "ranker")
        pipe.connect("ranker.documents", "prompt_builder.documents")
        pipe.connect("prompt_builder.prompt", "llm.messages")
        pipe.connect("llm.replies", "memory_joiner")
        pipe.connect("memory_joiner", "memory_writer")
        pipe.connect("memory_retriever", "prompt_builder.memories")
        self.pipe = pipe

    def run(self, q):
        """Answer ``q`` and return the reply text plus reference-link HTML.

        Args:
            q: The user's question. Empty or whitespace-only input returns
                immediately without running the pipeline.

        Returns:
            A dict with two keys:
            - ``'reply'``: the first LLM reply's ``content``, or ``''`` when
              the model returned no replies.
            - ``'sources'``: HTML linking the reranked documents, or ``''``
              when there are none.
        """
        print('q:', q)
        if not q or not q.strip():
            return {'reply': '', 'sources': ''}
        result = self.pipe.run(
            {
                "text_embedder": {"text": q},
                "bm25_retriever": {"query": q},
                "ranker": {"query": q},
                "prompt_builder": {"query": q},
                # The user's turn is merged with the LLM replies and written to memory.
                "memory_joiner": {"values": [ChatMessage.from_user(q)]},
            },
            # llm and ranker both feed other components, so their outputs are
            # requested explicitly.
            include_outputs_from=["llm", 'ranker'],
        )
        replies = result['llm']['replies']
        docs = result['ranker']['documents']
        reply = replies[0] if replies else None
        print('reply:', reply)
        reply_text = reply.content if reply is not None else ''
        return {'reply': reply_text, 'sources': self._build_sources_html(docs)}

    @staticmethod
    def _build_sources_html(docs):
        """Render ``docs`` as a "参考記事" (references) block of links.

        Each document's ``meta['title']`` and ``meta['link']`` are HTML-escaped
        before being placed into the markup. A missing key renders as an empty
        string. Returns ``''`` when ``docs`` is empty.
        """
        from html import escape

        if not docs:
            return ''
        rows = []
        for doc in docs:
            meta = doc.meta
            title = meta.get('title') or ''
            link = meta.get('link') or ''
            rows.append(
                f'<div><a class="link" href="{escape(str(link))}" target="_blank">'
                f'{escape(str(title))}</a></div>'
            )
            print('', title, link, meta.get('type'), doc.score)
        return (
            '<div class="ref-title">参考記事</div><div class="ref">'
            + ''.join(rows)
            + '</div>'
        )
# One shared assistant instance; its pipeline and chat-memory store are reused
# for every call. NOTE(review): all Gradio sessions therefore share the same
# conversation history. Confirm this is intended.
rag = Niwa_rag()


def fn(q, history):
    """Gradio chat callback: answer ``q`` and append the reference-link HTML.

    ``history`` is supplied by ``gr.ChatInterface`` but not used here;
    conversation context comes from the pipeline's own chat message store.
    """
    answer = rag.run(q)
    reply_text, sources_html = answer["reply"], answer["sources"]
    return reply_text + sources_html
# Chat UI: each submission calls fn(q, history). The CSS styles the
# "ref-title", "ref" and "link" classes used in the sources HTML that is
# appended to each reply.
app = gr.ChatInterface(
    fn,
    type="messages",
    textbox=gr.Textbox(placeholder='質問を記入して下さい', submit_btn=True),
    # `!important` must come before the semicolon. Written after it, as
    # `color:-webkit-link; !important`, it became a separate invalid
    # declaration and was ignored.
    css=(
        '.ref-title{margin-top:10px} '
        '.ref .link{margin-left:20px;font-size:90%;color:-webkit-link !important}'
    ),
)

if __name__ == "__main__":
    # share=True asks Gradio to create a public share link as well as the local server.
    app.launch(share=True)