karubiniumu commited on
Commit
3ac45a5
1 Parent(s): 71dbfab
Files changed (5) hide show
  1. .gitattributes +1 -0
  2. .gitignore +0 -0
  3. app.py +93 -0
  4. document_store.json +3 -0
  5. requirements.txt +7 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ document_store.json filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
File without changes
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from haystack.document_stores.in_memory import InMemoryDocumentStore
4
+ from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
5
+ from haystack.components.embedders import SentenceTransformersTextEmbedder
6
+ from haystack.components.joiners import DocumentJoiner
7
+ from haystack.components.rankers import TransformersSimilarityRanker
8
+ from haystack import Pipeline,component,Document
9
+ from haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator
10
+ from haystack.components.builders import PromptBuilder
11
+
12
+ @component
13
+ class AnswerBuilder :
14
+ @component.output_types(reply=str,documents=list[Document])
15
+ def run(self,replies:list[str],documents:list[Document]):
16
+ reply = replies[0]
17
+ return {"reply":reply,'documents':documents}
18
+
19
+ document_store = InMemoryDocumentStore.load_from_disk(path='./document_store.json')
20
+ print('document_store loaded' ,document_store.count_documents())
21
+
22
+ class Niwa_rag :
23
+ def __init__(self):
24
+ self.createPipe()
25
+ def createPipe(self):
26
+ template = """
27
+ 以下の情報に基づいて質問に答えて下さい。
28
+
29
+ Context:
30
+ {% for document in documents %}
31
+ {{ document.content }}
32
+ {% endfor %}
33
+
34
+ 質問: {{question}}
35
+ 回答:
36
+ """
37
+ text_embedder = SentenceTransformersTextEmbedder()
38
+ embedding_retriever = InMemoryEmbeddingRetriever(document_store)
39
+ bm25_retriever = InMemoryBM25Retriever(document_store)
40
+ document_joiner = DocumentJoiner()
41
+ ranker = TransformersSimilarityRanker(top_k=6)
42
+ prompt_builder = PromptBuilder(template=template)
43
+ gemini = GoogleAIGeminiGenerator(model="models/gemini-1.0-pro")
44
+ answer_builder = AnswerBuilder()
45
+
46
+ pipe = Pipeline()
47
+ pipe.add_component("text_embedder", text_embedder)
48
+ pipe.add_component("embedding_retriever", embedding_retriever)
49
+ pipe.add_component("bm25_retriever", bm25_retriever)
50
+ pipe.add_component("document_joiner", document_joiner)
51
+ pipe.add_component("ranker", ranker)
52
+ pipe.add_component("prompt_builder", prompt_builder)
53
+ pipe.add_component("llm", gemini)
54
+ pipe.add_component("answer_builder", answer_builder)
55
+
56
+ pipe.connect("text_embedder", "embedding_retriever")
57
+ pipe.connect("bm25_retriever", "document_joiner")
58
+ pipe.connect("embedding_retriever", "document_joiner")
59
+ pipe.connect("document_joiner", "ranker")
60
+ pipe.connect("ranker.documents", "prompt_builder.documents")
61
+ pipe.connect("prompt_builder.prompt", "llm")
62
+ pipe.connect("llm.replies", "answer_builder.replies")
63
+ pipe.connect("ranker.documents", "answer_builder.documents")
64
+ self.pipe = pipe
65
+ def run(self,q):
66
+ if not q :
67
+ return ['','']
68
+ result = self.pipe.run(
69
+ {"text_embedder": {"text": q}, "bm25_retriever": {"query": q}, "ranker": {"query": q},'prompt_builder':{'question':q}}
70
+ )
71
+ reply = result['answer_builder']['reply']
72
+ html = '<div>参考記事</div><div style="margin-left:30px">'
73
+ for doc in result['answer_builder']['documents'] :
74
+ title = doc.meta['title']
75
+ link = doc.meta['link']
76
+ row = f'<div><a href="{link}" target="_blank">{title}</a></div>'
77
+ html += row
78
+ html += '</div>'
79
+ return [reply,html]
80
+
81
+ rag = Niwa_rag()
82
+
83
+ with gr.Blocks() as app:
84
+ inputs=gr.Textbox(label='質問')
85
+ submit = gr.Button("送信",variant="primary")
86
+ reply =gr.Textbox(label='回答')
87
+ sources =gr.HTML(label='参考記事')
88
+ submit.click(lambda: gr.update(interactive=False),inputs=None, outputs=submit) \
89
+ .then(fn=rag.run, inputs=inputs, outputs=[reply,sources] ) \
90
+ .then(fn=lambda: gr.update(interactive=True),inputs=None, outputs=submit)
91
+
92
+ if __name__ == "__main__":
93
+ app.launch(share=True)
document_store.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4c44c65ed78f60d71791e1d1dc28451984561427236f4c231c013e9590d37cd
3
+ size 41714621
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ numpy
2
+ pandas
3
+ gradio
4
+ haystack-ai
5
+ google-ai-haystack
6
+ accelerate
7
+ sentence-transformers