# niwa-rag / app.py
# Author: karubiniumu
# Commit message: hybrid + emb_only
# Commit: 804908b
import gradio as gr
from pipe import pipe
import pytz
import datetime
from haystack.dataclasses import ChatMessage
def log_docs(docs):
    """Print one diagnostic line per document.

    Each line is the document's title, link, ``type`` metadata and score,
    separated by spaces and prefixed with a single space (the empty first
    argument to ``print``).
    """
    for document in docs:
        meta = document.meta
        print('', meta['title'], meta['link'], meta['type'], document.score)
def get_unique_docs(docs):
    """Return one document per ``meta['source_id']``, highest score first.

    For each source id, the first document encountered in ``docs`` is kept.
    Later documents with the same source id are dropped, even if they score
    higher. The survivors are then sorted by ``score`` in descending order.
    Documents with equal scores keep their first-appearance order, so the
    result is deterministic.

    Runs in a single pass plus a sort. The previous version rescanned
    ``docs`` once per source id and iterated a ``set``, which made tie
    order depend on the hash seed.
    """
    first_per_source = {}
    for doc in docs:
        # setdefault keeps the earliest document seen for each source id;
        # dict insertion order records first appearance.
        first_per_source.setdefault(doc.meta['source_id'], doc)
    # sorted() is stable even with reverse=True, so ties stay in
    # first-appearance order.
    return sorted(first_per_source.values(), key=lambda d: d.score, reverse=True)
def _create_response(reply, docs):
    """Build the ``{'reply', 'sources'}`` payload for one pipeline answer.

    Args:
        reply: An LLM reply object; only its ``content`` attribute is read.
        docs: Documents to cite. They are deduplicated per source via
            ``get_unique_docs`` and rendered as a list of links.

    Returns:
        ``{'reply': reply.content, 'sources': <HTML reference list>}``.
    """
    rows = [
        f'<div><a class="link" href="{doc.meta["link"]}" target="_blank">{doc.meta["title"]}</a></div>'
        for doc in get_unique_docs(docs)
    ]
    sources = '<div class="ref-title">参考記事</div><div class="ref">' + ''.join(rows) + '</div>'
    return {'reply': reply.content, 'sources': sources}


def run(pipe, q):
    """Answer ``q`` with both the hybrid and the embedding-only RAG branches.

    Args:
        pipe: A Haystack pipeline. It must contain the components referenced
            below: ``query_rephrase_prompt_builder``, ``prompt_builder``,
            ``prompt_builder_emb_only``, ``memory_joiner``,
            ``query_rephrase_llm``, ``bm25_retriever``,
            ``embedding_retriever``, ``ranker``, ``llm`` and ``llm_emb_only``.
        q: The user's question. An empty value short-circuits without
            running the pipeline.

    Returns:
        ``{'hybrid': resp, 'emb_only': resp}``, where each ``resp`` is
        ``{'reply': str, 'sources': str}``. This shape is the same for empty
        queries.

    Intermediate results (rephrased query, retrieved/ranked documents,
    replies) are printed to stdout for debugging.
    """
    now = datetime.datetime.now(pytz.timezone('Asia/Tokyo'))
    print('\nq:', q, now)
    if not q:
        # Return the same shape as the normal path. rag() indexes
        # result['hybrid'] and result['emb_only'] unconditionally, and the
        # old flat {'reply', 'sources'} dict made it raise KeyError.
        empty = {'reply': '', 'sources': ''}
        return {'hybrid': empty, 'emb_only': dict(empty)}

    result = pipe.run({
        "query_rephrase_prompt_builder": {"query": q},
        "prompt_builder": {"query": q},
        "prompt_builder_emb_only": {"query": q},
        "memory_joiner": {"values": [ChatMessage.from_user(q)]},
    }, include_outputs_from=["query_rephrase_llm", "llm", 'ranker', 'bm25_retriever', 'embedding_retriever', 'llm_emb_only'])

    query_rephrase = result['query_rephrase_llm']['replies'][0]
    print('query_rephrase:', query_rephrase)
    for retriever in ['bm25_retriever', 'embedding_retriever']:
        print(retriever)
        log_docs(result[retriever]['documents'])

    # Hybrid branch: answer from `llm`, citing the ranker's documents.
    reply = result['llm']['replies'][0]
    docs = result['ranker']['documents']
    print('ranker')
    log_docs(docs)
    print('reply:', reply.content)
    hybrid_response = _create_response(reply, docs)

    # Embedding-only branch: answer from `llm_emb_only`, citing the
    # embedding retriever's documents.
    reply_emb_only = result['llm_emb_only']['replies'][0]
    print('reply_emb_only :', reply_emb_only.content)
    emb_only_response = _create_response(reply_emb_only, result['embedding_retriever']['documents'])
    return {'hybrid': hybrid_response, 'emb_only': emb_only_response}
def rag(q, history):
    """Chat callback: answer ``q`` and show both retrieval modes as HTML.

    ``history`` is part of the callback signature but is not used here. The
    output is the hybrid-search answer and its references, followed by the
    vector-search-only answer and its references. Each section is introduced
    by an ``<h3>`` heading: "Hybrid search" and "Vector search only".
    """
    result = run(pipe, q)
    sections = []
    for heading, key in (('<h3>ハイブリッド検索</h3>', 'hybrid'),
                         ('<h3>ベクトル検索のみ</h3>', 'emb_only')):
        section = result[key]
        sections.append(heading + section['reply'] + section['sources'])
    return ''.join(sections)
# Chat UI: every submitted message is answered by rag(). The title reads
# "Garden fan Chatbot" and the placeholder "Please enter your question".
# css_paths loads ./app.css (its contents are not shown in this file).
app = gr.ChatInterface(
    fn=rag,
    type="messages",
    title='庭ファン Chatbot',
    css_paths='./app.css',
    textbox=gr.Textbox(
        placeholder='質問を記入して下さい',
        submit_btn=True,
    ),
)

if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly.
    app.launch()