niwa-rag / app.py
karubiniumu's picture
Update app.py
13a0823 verified
raw
history blame
5.77 kB
import gradio as gr
sentense_transformers_model = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
ranker_model = 'hotchpotch/japanese-reranker-cross-encoder-base-v1'
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.joiners import DocumentJoiner
from haystack.components.rankers import TransformersSimilarityRanker
from haystack import Pipeline,component,Document
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator
from haystack.components.builders import ChatPromptBuilder
from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore
from haystack_experimental.components.retrievers import ChatMessageRetriever
from haystack_experimental.components.writers import ChatMessageWriter
from haystack.dataclasses import ChatMessage
from itertools import chain
from typing import Any,List
from haystack.core.component.types import Variadic
# Load the pre-built in-memory document store from the JSON snapshot that sits
# next to this script, then log how many documents it contains as a quick
# sanity check at start-up.
document_store = InMemoryDocumentStore.load_from_disk(
    path='./document_store.json',
)
print('document_store loaded', document_store.count_documents())
@component
class ListJoiner:
    """Pipeline component that flattens several incoming lists into one list.

    The ``values`` input is declared ``Variadic`` so that it can be fed from
    more than one place; in this file it receives both the LLM replies and the
    user message passed in at run time.
    """

    def __init__(self, _type: Any):
        """Declare the single output socket ``values`` with the given type.

        Args:
            _type: Type annotation to advertise for the ``values`` output.
        """
        component.set_output_types(self, values=_type)

    def run(self, values: Variadic[Any]):
        """Concatenate every incoming iterable, preserving arrival order.

        Args:
            values: Iterable of iterables to be merged.

        Returns:
            ``{"values": merged}`` where ``merged`` is one flat list.
        """
        merged = [item for group in values for item in group]
        return {"values": merged}
class Niwa_rag:
    """Retrieval-augmented chat pipeline with conversation memory.

    The pipeline combines BM25 and embedding retrieval over the module-level
    ``document_store``, re-ranks the merged hits, builds a chat prompt from the
    ranked documents plus the stored chat history, and asks Gemini for a reply.
    The user message and the LLM reply are written to an in-memory chat
    message store owned by this instance.
    """

    def __init__(self):
        """Build the pipeline once so that every ``run`` call reuses it."""
        self.createPipe()

    def createPipe(self):
        """Assemble the Haystack pipeline and store it on ``self.pipe``.

        Wiring:
            text_embedder -> embedding_retriever --+
                                    bm25_retriever -+-> document_joiner -> ranker
            ranker.documents -> prompt_builder.documents
            memory_retriever -> prompt_builder.memories
            prompt_builder.prompt -> llm.messages
            llm.replies -> memory_joiner -> memory_writer
        """
        # NOTE: the template text is a runtime string and is kept exactly as is.
        user_message_template = """
γƒγƒ£γƒƒγƒˆε±₯ζ­΄γ¨ζδΎ›γ•γ‚ŒγŸθ³‡ζ–™γ«εŸΊγ₯いて、θ³ͺε•γ«η­”γˆγ¦γγ γ•γ„γ€‚
γƒγƒ£γƒƒγƒˆε±₯ζ­΄:
{% for memory in memories %}
{{ memory.content }}
{% endfor %}
資料:
{% for document in documents %}
{{ document.content }}
{% endfor %}
θ³ͺ問: {{query}}
ε›žη­”:
"""
        system_message = ChatMessage.from_system("あγͺγŸγ―γ€ζδΎ›γ•γ‚ŒγŸθ³‡ζ–™γ¨γƒγƒ£γƒƒγƒˆε±₯歴を使用して人間を支援するAIγ‚’γ‚·γ‚Ήγ‚Ώγƒ³γƒˆγ§γ™")
        user_message = ChatMessage.from_user(user_message_template)
        messages = [system_message, user_message]

        # Retrieval side: dense + sparse retrieval, merged and then re-ranked.
        text_embedder = SentenceTransformersTextEmbedder(model=sentense_transformers_model)
        embedding_retriever = InMemoryEmbeddingRetriever(document_store)
        bm25_retriever = InMemoryBM25Retriever(document_store)
        document_joiner = DocumentJoiner()
        ranker = TransformersSimilarityRanker(model=ranker_model, top_k=8)

        # Generation side: all three template variables are required.
        prompt_builder = ChatPromptBuilder(
            template=messages,
            variables=["query", "documents", "memories"],
            required_variables=["query", "documents", "memories"],
        )
        gemini = GoogleAIGeminiChatGenerator(model="models/gemini-1.0-pro")

        # Memory side: one store is both read (retriever) and written (writer).
        memory_store = InMemoryChatMessageStore()
        memory_joiner = ListJoiner(List[ChatMessage])
        memory_retriever = ChatMessageRetriever(memory_store)
        memory_writer = ChatMessageWriter(memory_store)

        pipe = Pipeline()
        pipe.add_component("text_embedder", text_embedder)
        pipe.add_component("embedding_retriever", embedding_retriever)
        pipe.add_component("bm25_retriever", bm25_retriever)
        pipe.add_component("document_joiner", document_joiner)
        pipe.add_component("ranker", ranker)
        pipe.add_component("prompt_builder", prompt_builder)
        pipe.add_component("llm", gemini)
        pipe.add_component("memory_retriever", memory_retriever)
        pipe.add_component("memory_writer", memory_writer)
        pipe.add_component("memory_joiner", memory_joiner)

        pipe.connect("text_embedder", "embedding_retriever")
        pipe.connect("bm25_retriever", "document_joiner")
        pipe.connect("embedding_retriever", "document_joiner")
        pipe.connect("document_joiner", "ranker")
        pipe.connect("ranker.documents", "prompt_builder.documents")
        pipe.connect("prompt_builder.prompt", "llm.messages")
        pipe.connect("llm.replies", "memory_joiner")
        pipe.connect("memory_joiner", "memory_writer")
        pipe.connect("memory_retriever", "prompt_builder.memories")
        self.pipe = pipe

    def run(self, q):
        """Answer one user question and list the documents used as references.

        Args:
            q: The user's question. A falsy value (``''``, ``None``) skips the
                pipeline entirely.

        Returns:
            A dict with two string entries: ``'reply'`` (content of the first
            LLM reply) and ``'sources'`` (an HTML fragment linking to the
            ranked documents). Both are empty strings when ``q`` is falsy.

        Raises:
            IndexError: If the LLM returns no replies.
            KeyError: If a ranked document lacks ``title``, ``link`` or
                ``type`` in its metadata.
        """
        print('q:', q)
        if not q:
            return {'reply': '', 'sources': ''}

        result = self.pipe.run(
            {
                "text_embedder": {"text": q},
                "bm25_retriever": {"query": q},
                "ranker": {"query": q},
                "prompt_builder": {"query": q},
                # The user message is joined with the LLM reply and written to memory.
                "memory_joiner": {"values": [ChatMessage.from_user(q)]},
            },
            include_outputs_from=["llm", 'ranker'],
        )
        reply = result['llm']['replies'][0]
        docs = result['ranker']['documents']
        print('reply:', reply)
        sources = self._render_sources(docs)
        return {'reply': reply.content, 'sources': sources}

    @staticmethod
    def _render_sources(docs):
        """Render ranked documents as an HTML block of reference links.

        Title and link come from stored document metadata, so both are
        HTML-escaped before being placed in element text / the ``href``
        attribute; otherwise characters such as ``&``, ``<`` or ``"`` would
        corrupt the markup.
        """
        # Local import keeps this block self-contained.
        from html import escape

        rows = []
        for doc in docs:
            title = doc.meta['title']
            link = doc.meta['link']
            print('', title, link, doc.meta['type'], doc.score)
            safe_link = escape(str(link), quote=True)
            safe_title = escape(str(title))
            rows.append(
                f'<div><a class="link" href="{safe_link}" target="_blank">{safe_title}</a></div>'
            )
        header = '<div class="ref-title">ε‚θ€ƒθ¨˜δΊ‹</div><div class="ref">'
        return ''.join([header, *rows, '</div>'])
# Build the RAG pipeline once at import time; this single instance (and the
# chat memory store it owns) is reused by every call to fn().
rag = Niwa_rag()
def fn(q, history):
    """Chat callback: answer ``q`` and append the reference links.

    Args:
        q: The message typed by the user.
        history: Conversation history supplied by the chat UI; not used here.

    Returns:
        The reply text immediately followed by the sources HTML fragment.
    """
    outcome = rag.run(q)
    reply, sources = outcome['reply'], outcome['sources']
    return reply + sources
# Chat UI: fn() is the message handler; custom styling is read from app.css.
app = gr.ChatInterface(
    fn=fn,
    type="messages",
    title='庭フゑン Chatbot',
    textbox=gr.Textbox(
        placeholder='θ³ͺε•γ‚’θ¨˜ε…₯して下さい',
        submit_btn=True,
    ),
    css_paths='./app.css',
)

# Start the web server only when executed as a script, not when imported.
if __name__ == '__main__':
    app.launch()