File size: 5,813 Bytes
3ac45a5
 
ee324f3
 
 
3ac45a5
 
 
 
 
 
57957da
 
 
 
 
 
 
 
 
3ac45a5
 
 
 
57957da
 
 
 
 
 
 
 
3ac45a5
 
 
 
57957da
 
 
 
 
 
 
3ac45a5
57957da
3ac45a5
 
 
 
57957da
3ac45a5
 
57957da
 
 
 
ee324f3
3ac45a5
 
 
57957da
 
 
 
 
 
 
 
3ac45a5
 
 
 
 
 
 
 
 
57957da
 
 
3ac45a5
 
 
 
 
 
57957da
 
 
 
 
3ac45a5
 
57957da
3ac45a5
57957da
 
 
 
 
 
 
 
 
 
 
 
 
3ac45a5
 
57957da
3ac45a5
f143400
3ac45a5
57957da
3ac45a5
 
 
57957da
 
 
 
 
 
 
c6f3317
57957da
 
 
3ac45a5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import gradio as gr

# Hugging Face model IDs used when the pipeline is built.
# Multilingual sentence-embedding model for the dense retriever. The misspelled
# name "sentense" is kept because the pipeline code refers to it by this name.
sentense_transformers_model = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
# Japanese cross-encoder that re-ranks the merged BM25 + embedding results.
ranker_model = "hotchpotch/japanese-reranker-cross-encoder-base-v1"

from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.joiners import DocumentJoiner
from haystack.components.rankers import TransformersSimilarityRanker
from haystack import Pipeline,component,Document
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator
from haystack.components.builders import ChatPromptBuilder
from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore
from haystack_experimental.components.retrievers import ChatMessageRetriever
from haystack_experimental.components.writers import ChatMessageWriter
from haystack.dataclasses import ChatMessage
from itertools import chain
from typing import Any,List
from haystack.core.component.types import Variadic

# Load the pre-built document index from disk; both retrievers in the RAG
# pipeline query this store. Log how many documents it holds.
document_store = InMemoryDocumentStore.load_from_disk(path='./document_store.json')
print(f'document_store loaded {document_store.count_documents()}')

@component
class ListJoiner:
    """Haystack component that flattens several incoming lists into one list.

    The output socket ``values`` is typed at construction time with ``_type``
    (for example ``List[ChatMessage]``), so downstream components receive a
    correctly typed input.
    """

    def __init__(self, _type: Any):
        component.set_output_types(self, values=_type)

    def run(self, values: Variadic[Any]):
        # Each item of ``values`` is one upstream list; concatenate them in order.
        flattened = list(chain.from_iterable(values))
        return {"values": flattened}

class Niwa_rag:
    """Hybrid-retrieval RAG chatbot built on a Haystack pipeline.

    Questions are answered by retrieving documents from the module-level
    ``document_store`` with both BM25 and embedding search, merging and
    re-ranking them, and passing them to Gemini together with the
    conversation history kept in an in-memory chat message store.
    """

    def __init__(self):
        self.createPipe()

    def createPipe(self):
        """Build the Haystack pipeline and store it on ``self.pipe``.

        Wiring (from the ``connect`` calls below):
          text_embedder -> embedding_retriever --+
          bm25_retriever -------------------------+-> document_joiner -> ranker
          ranker.documents + memory_retriever -> prompt_builder -> llm
          llm.replies (+ current user message) -> memory_joiner -> memory_writer
        """
        # User prompt template (Japanese, sent to the model verbatim). It asks the
        # model to answer the question based on the conversation history
        # ("会話の履歴") and the supplied documents ("資料").
        user_message_template = """
            会話の履歴と提供された資料に基づいて、質問に答えてください。
        
            会話の履歴:
            {% for memory in memories %}
                {{ memory.content }}
            {% endfor %}
        
            資料:
            {% for document in documents %}
                {{ document.content }}
            {% endfor %}
        
            質問: {{query}}
            回答:
        """
        # System prompt (Japanese): "You are an AI assistant that helps people
        # using the provided documents and conversation history."
        system_message = ChatMessage.from_system("あなたは、提供された資料と会話履歴を使用して人間を支援するAIアシスタントです")
        user_message = ChatMessage.from_user(user_message_template)
        messages = [system_message, user_message]

        # Retrieval: dense (embedding) and sparse (BM25) search over the same
        # store, merged by DocumentJoiner and re-ranked down to the top 8.
        text_embedder = SentenceTransformersTextEmbedder(model=sentense_transformers_model)
        embedding_retriever = InMemoryEmbeddingRetriever(document_store)
        bm25_retriever = InMemoryBM25Retriever(document_store)
        document_joiner = DocumentJoiner()
        ranker = TransformersSimilarityRanker(model=ranker_model,top_k=8)
        prompt_builder = ChatPromptBuilder(template=messages,variables=["query", "documents", "memories"], required_variables=["query", "documents", "memories"])
        gemini = GoogleAIGeminiChatGenerator(model="models/gemini-1.0-pro")

        # Conversation memory. NOTE(review): this store belongs to the instance,
        # and the module builds a single shared instance for every chat session,
        # so histories from different users are mixed. Confirm this is intended.
        memory_store = InMemoryChatMessageStore()
        memory_joiner = ListJoiner(List[ChatMessage])
        memory_retriever = ChatMessageRetriever(memory_store)
        memory_writer = ChatMessageWriter(memory_store)

        pipe = Pipeline()
        pipe.add_component("text_embedder", text_embedder)
        pipe.add_component("embedding_retriever", embedding_retriever)
        pipe.add_component("bm25_retriever", bm25_retriever)
        pipe.add_component("document_joiner", document_joiner)
        pipe.add_component("ranker", ranker)
        pipe.add_component("prompt_builder", prompt_builder)
        pipe.add_component("llm", gemini)
        pipe.add_component("memory_retriever", memory_retriever)
        pipe.add_component("memory_writer", memory_writer)
        pipe.add_component("memory_joiner", memory_joiner)

        pipe.connect("text_embedder", "embedding_retriever")
        pipe.connect("bm25_retriever", "document_joiner")
        pipe.connect("embedding_retriever", "document_joiner")
        pipe.connect("document_joiner", "ranker")
        pipe.connect("ranker.documents", "prompt_builder.documents")
        pipe.connect("prompt_builder.prompt", "llm.messages")
        # The LLM reply is joined with the user turn (supplied in run()) and
        # written back to memory for later questions.
        pipe.connect("llm.replies", "memory_joiner")
        pipe.connect("memory_joiner", "memory_writer")
        pipe.connect("memory_retriever", "prompt_builder.memories")

        self.pipe = pipe

    def run(self, q):
        """Answer question ``q``; return the reply text and an HTML source list.

        Args:
            q: The user's question. A falsy value (``''``/``None``) returns
                empty results without running the pipeline.

        Returns:
            dict with ``'reply'`` (the first LLM reply's content, or ``''`` if
            the generator returned no replies) and ``'sources'`` (an HTML
            fragment linking the re-ranked documents).
        """
        print('q:', q)
        if not q:
            return {'reply': '', 'sources': ''}
        result = self.pipe.run({
            "text_embedder": {"text": q},
            "bm25_retriever": {"query": q},
            "ranker": {"query": q},
            "prompt_builder": {"query": q},
            # Current user turn; stored in memory together with the LLM reply.
            "memory_joiner": {"values": [ChatMessage.from_user(q)]},
        }, include_outputs_from=["llm", 'ranker'])
        replies = result['llm']['replies']
        # Avoid IndexError if the generator produced no reply.
        reply = replies[0] if replies else None
        docs = result['ranker']['documents']
        print('reply:', reply)
        sources = self._format_sources(docs)
        reply_text = reply.content if reply is not None else ''
        return {'reply': reply_text, 'sources': sources}

    @staticmethod
    def _format_sources(docs):
        """Render ``docs`` as an HTML block of reference links ("参考記事").

        Title and link come from document metadata and are HTML-escaped before
        interpolation. A missing ``link`` becomes ``''``. A missing ``title``
        falls back to the link. A missing ``type`` prints as ``None``.
        """
        import html

        rows = []
        for doc in docs:
            meta = doc.meta or {}
            link = meta.get('link', '')
            title = meta.get('title') or link
            rows.append(
                f'<div><a class="link" href="{html.escape(str(link), quote=True)}" '
                f'target="_blank">{html.escape(str(title))}</a></div>'
            )
            # Per-document diagnostic line.
            print('', title, link, meta.get('type'), doc.score)
        return '<div class="ref-title">参考記事</div><div class="ref">' + ''.join(rows) + '</div>'

# One RAG pipeline is built at import time and reused for every chat request.
rag = Niwa_rag()


def fn(q, history):
    """Gradio chat callback: answer ``q`` and append the HTML source list.

    ``history`` is passed by ``gr.ChatInterface`` but is not used here; the
    conversation memory lives inside the RAG pipeline instead.
    """
    answer = rag.run(q)
    reply_text = answer['reply']
    sources_html = answer['sources']
    return reply_text + sources_html

# Stylesheet for the reference-articles block appended to each reply.
# `!important` must come before the `;` terminating its declaration; placed
# after it (as `color:-webkit-link; !important`) it is an invalid stray token
# and the importance flag is never applied.
_CHAT_CSS = (
    '.ref-title{margin-top:10px} '
    '.ref .link{margin-left:20px;font-size:90%;color:-webkit-link !important}'
)

# Chat UI. The placeholder text (Japanese) reads "Please enter your question".
app = gr.ChatInterface(
    fn,
    type="messages",
    textbox=gr.Textbox(placeholder='質問を記入して下さい', submit_btn=True),
    css=_CHAT_CSS,
)


def main():
    """Start the Gradio chat app.

    ``share=True`` is passed through to Gradio's ``launch``. Per Gradio's docs
    this requests a public share link; confirm that exposure is intended.
    """
    app.launch(share=True)


if __name__ == "__main__":
    main()