|
|
|
"""wiki_chat_3_hack.ipynb |
|
|
|
Automatically generated by Colaboratory. |
|
|
|
Original file is located at |
|
https://colab.research.google.com/drive/1chXsWeq1LzbvYIs6H73gibYmNDRbIgkD |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gradio as gr |
|
|
|
from sentence_transformers import SentenceTransformer, CrossEncoder, util |
|
from torch import tensor as torch_tensor |
|
from datasets import load_dataset |
|
|
|
"""# import models""" |
|
|
|
# Bi-encoder: turns the user query into an embedding for the first-stage
# semantic search over the precomputed passage embeddings (see ``search``).
bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
bi_encoder.max_seq_length = 256  # input-length limit applied by the encoder

# Cross-encoder: re-scores each (query, passage) pair returned by the
# bi-encoder stage (see ``search``).
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# --- import datasets ---

dataset = load_dataset("sritang/hack_df", split='train')
# Convert to pandas once and reuse the frame for both columns; converting
# separately for every column repeats the same work.
_passages_df = dataset.to_pandas()
mypassages = list(_passages_df['paragraph'])  # passage texts
mysource = list(_passages_df['Source'])       # source label for each passage

# Precomputed passage embeddings.
# NOTE(review): presumably row i is the embedding of mypassages[i] -- the
# ``corpus_id`` returned by the search is used to index both; confirm the two
# datasets are aligned.
dataset_embed = load_dataset("sritang/hack_policy_embed_sri", split='train')
dataset_embed_pd = dataset_embed.to_pandas()
mycorpus_embeddings = torch_tensor(dataset_embed_pd.values)
|
|
|
def search(query, passages=mypassages, doc_embedding=mycorpus_embeddings, top_k=20, top_n=1):
    """Retrieve the passages most relevant to ``query`` and re-rank them.

    The query is embedded with the module-level ``bi_encoder`` and matched
    against ``doc_embedding`` with ``util.semantic_search``. Each candidate
    passage is then scored against the query by the module-level
    ``cross_encoder`` and the candidates are ordered by that score.

    Args:
        query: The question text.
        passages: Sequence of passage texts, indexed by the ``corpus_id`` of
            each hit.
        doc_embedding: Embeddings searched by the bi-encoder stage.
        top_k: Number of candidates requested from the bi-encoder stage.
        top_n: Number of re-ranked hits to return.

    Returns:
        A list of at most ``top_n`` hit dicts as produced by
        ``util.semantic_search`` (each has a ``'corpus_id'`` key), extended
        with a ``'cross-score'`` key and sorted by it in descending order.
        An empty list is returned when the first stage yields no hits.
    """
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    # semantic_search returns one hit list per query; there is one query here.
    hits = util.semantic_search(question_embedding, doc_embedding, top_k=top_k)[0]
    if not hits:
        # Nothing to re-rank; avoid handing the cross-encoder an empty batch.
        return []

    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)

    # Attach each cross-encoder score to its candidate.
    for hit, score in zip(hits, cross_scores):
        hit['cross-score'] = score

    ranked = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
    return ranked[:top_n]
|
|
|
|
|
|
|
def get_text(qry):
    """Return the text of the top re-ranked passage(s) for ``qry``.

    Calls ``search`` with its default settings and looks each returned hit
    up in the module-level ``mypassages`` list.

    Args:
        qry: The question text.

    Returns:
        A list of passage strings, one per hit returned by ``search``.
    """
    return ["{}".format(mypassages[match['corpus_id']]) for match in search(qry)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""# new LLM based functions""" |
|
|
|
import os

# SECURITY: an OpenAI secret key used to be hard-coded at this point and
# written into os.environ. Secrets must not live in source code; the key that
# was committed here should be treated as leaked and revoked. The key is now
# expected to come from the process environment (e.g. a deployment secret).
if not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError(
        "OPENAI_API_KEY is not set; export it (or configure it as a "
        "deployment secret) before starting the app."
    )

from langchain.llms import OpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings

from langchain.text_splitter import CharacterTextSplitter

from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.chains import VectorDBQAWithSourcesChain

# Question-answering chain that also reports sources. It is run with the
# retrieved documents and the user question in ``get_llm_response``.
# NOTE(review): per LangChain, chain_type="stuff" places all input documents
# into a single prompt -- confirm against the installed LangChain version.
chain_qa = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="stuff")
|
|
|
def get_text_fmt(qry, passages=mypassages, doc_embedding=mycorpus_embeddings, sources=None):
    """Retrieve the top passages for ``qry`` as LangChain ``Document`` objects.

    Runs ``search`` with ``top_n=5`` and wraps every hit in a ``Document``
    whose ``page_content`` is the passage text and whose metadata carries the
    passage's source under the ``"source"`` key.

    Args:
        qry: The question text.
        passages: Sequence of passage texts, indexed by ``corpus_id``.
        doc_embedding: Embeddings searched for ``qry``.
        sources: Sequence of source labels, indexed by ``corpus_id`` in the
            same way as ``passages``. Defaults to the module-level
            ``mysource``. Pass it whenever a custom ``passages`` is passed,
            so that the source metadata matches the passage text.

    Returns:
        A list of ``Document`` objects, one per hit returned by ``search``.
    """
    if sources is None:
        # Looked up at call time, exactly as the global was read before.
        sources = mysource

    predictions = search(qry, passages=passages, doc_embedding=doc_embedding, top_n=5)
    return [
        Document(
            page_content=passages[hit['corpus_id']],
            metadata={"source": sources[hit['corpus_id']]},
        )
        for hit in predictions
    ]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_llm_response(message):
    """Answer ``message`` with the QA chain over the retrieved passages.

    The documents returned by ``get_text_fmt(message)`` are passed to the
    module-level ``chain_qa`` together with the question itself.

    Args:
        message: The question text.

    Returns:
        Whatever ``chain_qa.run`` returns for the documents and question.
    """
    return chain_qa.run(input_documents=get_text_fmt(message), question=message)
|
|
|
"""# chat example""" |
|
|
|
def chat(message, history):
    """Handle one chat turn for the Gradio UI.

    The message is lower-cased, answered via ``get_llm_response``, and the
    (lower-cased message, response) pair is appended to the conversation.

    Args:
        message: Text entered by the user.
        history: List of (message, response) tuples from earlier turns, or a
            falsy value (e.g. ``None``) when the conversation is new.

    Returns:
        The updated history twice: the click handler maps the two outputs to
        the chatbot display and to the stored state.
    """
    if not history:
        history = []

    message = message.lower()
    history.append((message, get_llm_response(message)))
    return history, history
|
|
|
# Page-level CSS: light gray background for the Gradio container.
css = ".gradio-container {background-color: lightgray}"

# NOTE(review): the source lost its indentation; the nesting of the layout
# below (in particular, the Send button sharing a row with the textbox) was
# reconstructed -- confirm against the running app.
with gr.Blocks(css=css) as demo:
    # Conversation history carried between clicks; see ``chat``.
    history_state = gr.State()
    gr.Markdown('# Hack QA')
    # NOTE(review): these two names are assigned but never passed to any
    # component, so they are not shown in the UI.
    title = 'Benefit Chatbot'
    description = 'chatbot with search on Health Benefits'
    with gr.Row():
        chatbot = gr.Chatbot()
    with gr.Row():
        message = gr.Textbox(label='Input your question here:',
                             placeholder='What is the name of the plan described by this summary of benefits?',
                             lines=1)
        # NOTE(review): ``.style(...)`` depends on the installed Gradio
        # version -- confirm it is still available before upgrading.
        submit = gr.Button(value='Send',
                           variant='secondary').style(full_width=False)
    # One click runs ``chat`` and refreshes both the display and the state.
    submit.click(chat,
                 inputs=[message, history_state],
                 outputs=[chatbot, history_state])
    gr.Examples(
        examples=["How often can we have root canal treatment for h5216?",
                  "Compare H1036 or H5216 based on monthly plan premium and recommend the best plan?",
                  "I use urgent care a lot, what plan do you recommend based on low copay?",
                  "I am looking for a plan sold in clark indiana, has transportation which is covered?",
                  "Name a plan that is available in both clark county Indiana and Floyd in Kentucky"],
        inputs=message
    )

# Launch only when executed as a script, so that importing this module (for
# example to reuse ``search``) does not start a web server as a side effect.
if __name__ == "__main__":
    demo.launch()
|
|
|
|