"""Gradio chatbot backed by a Haystack retrieval-augmented pipeline.

The pipeline combines BM25 and embedding retrieval, cross-encoder reranking,
an in-memory chat history and a Gemini chat generator.
"""
import gradio as gr
import datetime
import pytz

# Model name given to SentenceTransformersTextEmbedder in Niwa_rag.createPipe.
sentense_transformers_model = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
# Model name given to TransformersSimilarityRanker in Niwa_rag.createPipe.
ranker_model = "hotchpotch/japanese-reranker-cross-encoder-base-v1"

from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.joiners import DocumentJoiner
from haystack.components.rankers import TransformersSimilarityRanker
from haystack import Pipeline,component,Document
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator
from haystack.components.builders import ChatPromptBuilder
from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore
from haystack_experimental.components.retrievers import ChatMessageRetriever
from haystack_experimental.components.writers import ChatMessageWriter
from haystack.dataclasses import ChatMessage
from itertools import chain
from typing import Any,List
from haystack.core.component.types import Variadic

# Load the serialized document store from ./document_store.json at import time
# and log how many documents it holds. The retrievers built in
# Niwa_rag.createPipe read from this module-level store.
document_store = InMemoryDocumentStore.load_from_disk(path="./document_store.json")
print("document_store loaded", document_store.count_documents())

@component
class ListJoiner:
    """Pipeline component that concatenates several input lists into one.

    Every value received on the variadic ``values`` input is expected to be
    an iterable; their items are emitted, in arrival order, as a single list
    on the ``values`` output.
    """

    def __init__(self, _type: Any):
        """Declare ``_type`` as the type of the ``values`` output socket."""
        component.set_output_types(self, values=_type)

    def run(self, values: Variadic[Any]):
        """Flatten ``values`` one level and return ``{"values": flat_list}``."""
        flattened = [item for group in values for item in group]
        return {"values": flattened}

class Niwa_rag:
    """Question-answering wrapper around a single Haystack ``Pipeline``.

    The pipeline retrieves documents from the module-level ``document_store``
    with both BM25 and embedding retrieval, joins and reranks them, builds a
    chat prompt that also includes stored chat messages, and sends it to a
    Gemini chat generator.

    NOTE(review): one ``InMemoryChatMessageStore`` is created per instance, so
    all calls to ``run`` on the same instance presumably share one chat
    history -- confirm this is intended when several users share an instance.
    """

    def __init__(self):
        """Build the pipeline immediately so the instance is ready to use."""
        self.createPipe()

    def createPipe(self):
        """Create every component, wire them together and store ``self.pipe``."""
        # Jinja template for the user turn; rendered by ChatPromptBuilder with
        # the variables ``memories``, ``documents`` and ``query``.
        prompt_template = """
            チャット履歴と提供された資料に基づいて、質問に答えてください。
            資料はチャット履歴の一部ではないことに注意してください。
            質問が資料から回答できない場合は、その旨を述べてください。
        
            チャット履歴:
            {% for memory in memories %}
                {{ memory.content }}
            {% endfor %}
        
            資料:
            {% for document in documents %}
                {{ document.content }}
            {% endfor %}
        
            質問: {{query}}
            回答:
        """
        chat_template = [
            ChatMessage.from_system("あなたは、提供された資料とチャット履歴を使用して人間を支援するAIアシスタントです"),
            ChatMessage.from_user(prompt_template),
        ]

        # Retrieval, reranking and generation components.
        text_embedder = SentenceTransformersTextEmbedder(model=sentense_transformers_model)
        embedding_retriever = InMemoryEmbeddingRetriever(document_store)
        bm25_retriever = InMemoryBM25Retriever(document_store)
        document_joiner = DocumentJoiner()
        ranker = TransformersSimilarityRanker(model=ranker_model, top_k=8)
        template_variables = ["query", "documents", "memories"]
        prompt_builder = ChatPromptBuilder(
            template=chat_template,
            variables=template_variables,
            required_variables=["query", "documents", "memories"],
        )
        gemini = GoogleAIGeminiChatGenerator(model="models/gemini-1.5-flash")

        # Chat-history components, all backed by the same in-memory store.
        memory_store = InMemoryChatMessageStore()
        memory_joiner = ListJoiner(List[ChatMessage])
        memory_retriever = ChatMessageRetriever(memory_store)
        memory_writer = ChatMessageWriter(memory_store)

        pipe = Pipeline()

        # Register components under the names that run() and the connection
        # table below refer to.
        registrations = (
            ("text_embedder", text_embedder),
            ("embedding_retriever", embedding_retriever),
            ("bm25_retriever", bm25_retriever),
            ("document_joiner", document_joiner),
            ("ranker", ranker),
            ("prompt_builder", prompt_builder),
            ("llm", gemini),
            ("memory_retriever", memory_retriever),
            ("memory_writer", memory_writer),
            ("memory_joiner", memory_joiner),
        )
        for component_name, instance in registrations:
            pipe.add_component(component_name, instance)

        # (sender, receiver) pairs: both retrievers feed the joiner, whose
        # output is reranked and passed to the prompt builder; the LLM replies
        # go through memory_joiner into memory_writer, and memory_retriever
        # supplies the ``memories`` prompt variable.
        wiring = (
            ("text_embedder", "embedding_retriever"),
            ("bm25_retriever", "document_joiner"),
            ("embedding_retriever", "document_joiner"),
            ("document_joiner", "ranker"),
            ("ranker.documents", "prompt_builder.documents"),
            ("prompt_builder.prompt", "llm.messages"),
            ("llm.replies", "memory_joiner"),
            ("memory_joiner", "memory_writer"),
            ("memory_retriever", "prompt_builder.memories"),
        )
        for sender, receiver in wiring:
            pipe.connect(sender, receiver)

        self.pipe = pipe

    def run(self, q):
        """Answer the question ``q`` and list the documents used as sources.

        Args:
            q: The user's question. A falsy value (e.g. ``''`` or ``None``)
                skips the pipeline entirely.

        Returns:
            A dict with two string entries: ``'reply'`` (content of the first
            LLM reply) and ``'sources'`` (an HTML fragment with one link per
            reranked document). Both are empty strings when ``q`` is falsy.

        Raises:
            KeyError: If a reranked document's ``meta`` lacks ``'title'``,
                ``'link'`` or ``'type'``.
        """
        timestamp = datetime.datetime.now(pytz.timezone('Asia/Tokyo'))
        print('q:', q, timestamp)
        if not q:
            return {'reply': '', 'sources': ''}

        # The same question text goes to every component that needs it; the
        # user message is also handed to memory_joiner alongside the reply.
        pipeline_inputs = {
            "text_embedder": {"text": q},
            "bm25_retriever": {"query": q},
            "ranker": {"query": q},
            "prompt_builder": {"query": q},
            "memory_joiner": {"values": [ChatMessage.from_user(q)]},
        }
        outputs = self.pipe.run(pipeline_inputs, include_outputs_from=["llm", 'ranker'])

        answer = outputs['llm']['replies'][0]
        ranked_docs = outputs['ranker']['documents']
        print('reply:', answer)

        # Assemble the source list as HTML, logging each document as we go.
        fragments = ['<div class="ref-title">参考記事</div><div class="ref">']
        for ranked in ranked_docs:
            meta = ranked.meta
            title, link = meta['title'], meta['link']
            fragments.append(f'<div><a class="link" href="{link}" target="_blank">{title}</a></div>')
            print('', title, link, meta['type'], ranked.score)
        fragments.append('</div>')

        return {'reply': answer.content, 'sources': ''.join(fragments)}

# Build the pipeline once at import time; fn() below reuses this single
# instance for every chat message.
rag = Niwa_rag()

def fn(q, history):
    """Chat callback passed to ``gr.ChatInterface``.

    Args:
        q: The message the user just submitted.
        history: Accepted to match the callback signature but not used here.

    Returns:
        The reply text followed directly by the HTML list of sources, as one
        string.
    """
    outcome = rag.run(q)
    reply_text, sources_html = outcome['reply'], outcome['sources']
    return reply_text + sources_html

# Input box for the chat UI, created up front so the ChatInterface call below
# stays short.
question_box = gr.Textbox(placeholder='質問を記入して下さい', submit_btn=True)

# Chat UI: every submitted message is answered by fn().
app = gr.ChatInterface(
    fn=fn,
    type="messages",
    title='庭ファン Chatbot',
    textbox=question_box,
    css_paths='./app.css',
)


# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    app.launch()