# niwa-rag / app.py
# Author: karubiniumu · last commit 7068013 ("gemini-1.5-flash") · 6.06 kB
import gradio as gr
import datetime
import pytz
# Model id of the sentence-embedding model passed to SentenceTransformersTextEmbedder;
# its query embeddings feed the in-memory embedding retriever.
SENTENCE_TRANSFORMERS_MODEL = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
# Model id of the cross-encoder passed to TransformersSimilarityRanker, which
# re-ranks the merged BM25 + embedding hits.
RANKER_MODEL = 'hotchpotch/japanese-reranker-cross-encoder-base-v1'
# Backward-compatible aliases: the original (misspelled, lowercase) names are
# still referenced elsewhere in this module.
sentense_transformers_model = SENTENCE_TRANSFORMERS_MODEL
ranker_model = RANKER_MODEL
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.joiners import DocumentJoiner
from haystack.components.rankers import TransformersSimilarityRanker
from haystack import Pipeline,component,Document
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator
from haystack.components.builders import ChatPromptBuilder
from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore
from haystack_experimental.components.retrievers import ChatMessageRetriever
from haystack_experimental.components.writers import ChatMessageWriter
from haystack.dataclasses import ChatMessage
from itertools import chain
from typing import Any,List
from haystack.core.component.types import Variadic
# Pre-built document index loaded once at import time. It is queried by both the
# BM25 and the embedding retriever in Niwa_rag (so it presumably already holds
# document embeddings — confirm against the indexing script).
# NOTE: the path is relative to the current working directory.
_DOCUMENT_STORE_PATH = './document_store.json'
document_store = InMemoryDocumentStore.load_from_disk(path=_DOCUMENT_STORE_PATH)
print('document_store loaded', document_store.count_documents())
@component
class ListJoiner:
    """Haystack component that flattens several list inputs into a single list.

    In this app it merges the user's question message with the LLM replies so
    both can be handed to the chat-memory writer in one call.
    """

    def __init__(self, _type: Any):
        # Output socket type must be declared per instance (e.g. List[ChatMessage]).
        component.set_output_types(self, values=_type)

    def run(self, values: Variadic[Any]):
        """Concatenate every incoming list, keeping the order they arrive in."""
        merged = []
        for chunk in values:
            merged.extend(chunk)
        return {"values": merged}
class Niwa_rag:
    """Conversational RAG over the module-level ``document_store`` with Gemini.

    A query is retrieved with both BM25 and dense embeddings. The hits are merged
    and re-ranked by a cross-encoder, and the re-ranked documents plus the stored
    chat history are rendered into a prompt for Gemini. The user's message and
    the model's replies are then written back to an in-memory chat store.

    NOTE: the chat store is created per instance, so every caller sharing one
    ``Niwa_rag`` object also shares one conversation history.
    """

    def __init__(self):
        self.createPipe()

    def createPipe(self):
        """Build the retrieval + generation pipeline and store it as ``self.pipe``."""
        # Prompt (Japanese): answer from the chat history and the provided material;
        # the material is not part of the chat history; if the material cannot
        # answer the question, say so.
        user_message_template = """
チャット履歴と提供された資料に基づいて、質問に答えてください。
資料はチャット履歴の一部ではないことに注意してください。
質問が資料から回答できない場合は、その旨を述べてください。
チャット履歴:
{% for memory in memories %}
{{ memory.content }}
{% endfor %}
資料:
{% for document in documents %}
{{ document.content }}
{% endfor %}
質問: {{query}}
回答:
"""
        # System prompt: "You are an AI assistant that helps people using the
        # provided material and the chat history."
        system_message = ChatMessage.from_system("あなたは、提供された資料とチャット履歴を使用して人間を支援するAIアシスタントです")
        user_message = ChatMessage.from_user(user_message_template)
        messages = [system_message, user_message]

        text_embedder = SentenceTransformersTextEmbedder(model=sentense_transformers_model)
        embedding_retriever = InMemoryEmbeddingRetriever(document_store)
        bm25_retriever = InMemoryBM25Retriever(document_store)
        document_joiner = DocumentJoiner()
        ranker = TransformersSimilarityRanker(model=ranker_model, top_k=8)
        prompt_builder = ChatPromptBuilder(
            template=messages,
            variables=["query", "documents", "memories"],
            required_variables=["query", "documents", "memories"],
        )
        gemini = GoogleAIGeminiChatGenerator(model="models/gemini-1.5-flash")
        memory_store = InMemoryChatMessageStore()
        memory_joiner = ListJoiner(List[ChatMessage])
        memory_retriever = ChatMessageRetriever(memory_store)
        memory_writer = ChatMessageWriter(memory_store)

        pipe = Pipeline()
        pipe.add_component("text_embedder", text_embedder)
        pipe.add_component("embedding_retriever", embedding_retriever)
        pipe.add_component("bm25_retriever", bm25_retriever)
        pipe.add_component("document_joiner", document_joiner)
        pipe.add_component("ranker", ranker)
        pipe.add_component("prompt_builder", prompt_builder)
        pipe.add_component("llm", gemini)
        pipe.add_component("memory_retriever", memory_retriever)
        pipe.add_component("memory_writer", memory_writer)
        pipe.add_component("memory_joiner", memory_joiner)

        # Hybrid retrieval: dense and BM25 hits are merged, re-ranked, then prompted.
        pipe.connect("text_embedder", "embedding_retriever")
        pipe.connect("bm25_retriever", "document_joiner")
        pipe.connect("embedding_retriever", "document_joiner")
        pipe.connect("document_joiner", "ranker")
        pipe.connect("ranker.documents", "prompt_builder.documents")
        pipe.connect("prompt_builder.prompt", "llm.messages")
        # Memory: replies are joined with the user's message (supplied at run time)
        # and written to the chat store; stored messages are read back into the prompt.
        pipe.connect("llm.replies", "memory_joiner")
        pipe.connect("memory_joiner", "memory_writer")
        pipe.connect("memory_retriever", "prompt_builder.memories")
        self.pipe = pipe

    def run(self, q):
        """Answer one question.

        Args:
            q: The user's question. Empty or whitespace-only input short-circuits
                without running the pipeline.

        Returns:
            A dict with ``'reply'`` (answer text) and ``'sources'`` (an HTML snippet
            linking the re-ranked reference documents).
        """
        now = datetime.datetime.now(pytz.timezone('Asia/Tokyo'))
        print('q:', q, now)
        # Whitespace-only input would otherwise reach the LLM and be stored in memory.
        if not q or not q.strip():
            return {'reply': '', 'sources': ''}
        result = self.pipe.run({
            "text_embedder": {"text": q},
            "bm25_retriever": {"query": q},
            "ranker": {"query": q},
            "prompt_builder": {"query": q},
            "memory_joiner": {"values": [ChatMessage.from_user(q)]},
        }, include_outputs_from={"llm", "ranker"})
        replies = result['llm']['replies']
        docs = result['ranker']['documents']
        if not replies:
            # Guard against the generator returning no replies (was an IndexError).
            print('reply:', None)
            return {'reply': '回答を生成できませんでした。', 'sources': self._render_sources(docs)}
        reply = replies[0]
        print('reply:', reply)
        return {'reply': reply.content, 'sources': self._render_sources(docs)}

    @staticmethod
    def _render_sources(docs):
        """Render *docs* as the "参考記事" (reference articles) HTML link list.

        Each document's ``meta['title']`` and ``meta['link']`` are HTML-escaped so
        characters such as ``&``, ``<`` or ``"`` cannot break the markup. A missing
        key renders as an empty string. Each document is also logged with its
        ``type`` metadata and ranker score.
        """
        from html import escape

        rows = []
        for doc in docs:
            meta = doc.meta
            title = meta.get('title', '')
            link = meta.get('link', '')
            rows.append(
                f'<div><a class="link" href="{escape(str(link))}" target="_blank">'
                f'{escape(str(title))}</a></div>'
            )
            print('', title, link, meta.get('type'), doc.score)
        return '<div class="ref-title">参考記事</div><div class="ref">' + ''.join(rows) + '</div>'
# Single pipeline instance shared by every call to ``fn``.
rag = Niwa_rag()


def fn(q, history):
    """Gradio chat callback: answer *q* and append the reference-link HTML.

    ``history`` is supplied by ``gr.ChatInterface`` but is not used here;
    conversation memory is kept inside ``rag``'s pipeline instead.
    """
    answer = rag.run(q)
    reply_text, sources_html = answer['reply'], answer['sources']
    return reply_text + sources_html
# Input box; placeholder reads "Please enter your question".
_chat_textbox = gr.Textbox(placeholder='質問を記入して下さい', submit_btn=True)

# Chat UI ("Niwa Fan Chatbot"); every submitted message is answered by ``fn``.
app = gr.ChatInterface(
    fn,
    type="messages",
    title='庭ファン Chatbot',
    textbox=_chat_textbox,
    css_paths='./app.css',
)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script.
    app.launch()