# Sheldon_Retrieval_chat_bot / retrieve_bot.py
import pandas as pd
import pickle
import random
from sentence_transformers import SentenceTransformer
from utils import (
    encode,
    cosine_sim,
    top_candidates,
    candidates_reranking,
    intent_classification,
)
from collections import deque
import torch
from transformers import AutoTokenizer, pipeline
from dialog_tag import DialogTag


# Canned fallback replies returned when no retrieved candidate scores high enough.
low_scoring_list = [
    "What does it mean?",
    "You have two strikes. Three strikes and you're out. It's a sports metaphor. Explain again!",
    "Again, urban slang. In which, I believe I'm gaining remarkable fluency. So, could you repeat?",
    "I'm confused.",
    "I can't comment without violating our agreement that I don't criticize you.",
    "Oh!",
    "I need to use the restroom.",
    "Move. Move. Move!",
    "I was going to mention it at the time, but then I thought, some day maybe...",
    "Well...",
    "Apparently... I have no idea!?",
    "I'm not sure...",
    "Nothing. I say nothing.",
    "Well, my friend. Focus and repeat!",
]


class ChatBot:
    """This class represents the main functions of the retrieval bot."""

    def __init__(self):
        self.vect_data = []
        self.scripts = []
        self.conversation_history = deque([], maxlen=5)
        self.tag_model = None
        self.ranking_model = None
        self.reranking_model = None
        self.device = None
        self.tokenizer_reranking = None

    def load(self):
        """This method is called first to load all datasets and
        models used by the chat bot; all data is stored in the
        data folder, and models are loaded from Hugging Face."""
        with open("data/scripts_vectors.pkl", "rb") as fp:
            self.vect_data = pickle.load(fp)
        self.scripts = pd.read_pickle("data/scripts.pkl")
        self.tag_model = DialogTag("distilbert-base-uncased")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.ranking_model = SentenceTransformer(
            "Shakhovak/chatbot_sentence-transformer"
        )  # alternatives: sentence-transformers/LaBSE or sentence-transformers/all-mpnet-base-v2
        self.tokenizer_reranking = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.reranking_model = pipeline(
            model="Shakhovak/RerankerModel_chat_bot",
            device=self.device,
            tokenizer=self.tokenizer_reranking,
        )

    def generate_response(self, utterance: str) -> str:
        """This function identifies potential answer
        candidates and ranks them."""
        intent = intent_classification(utterance, utterance, self.tag_model)
        query_encoding = encode(
            texts=utterance,
            intent=intent,
            model=self.ranking_model,
            contexts=self.conversation_history,
        )
        bot_cosine_scores = cosine_sim(
            self.vect_data,
            query_encoding,
        )
        top_scores, top_indexes = top_candidates(
            bot_cosine_scores, intent=intent, initial_data=self.scripts, top=10
        )
        if top_scores[0] < 0.9:
            # Even the best candidate is not similar enough: fall back to a
            # canned reply and reset the dialogue history.
            answer = random.choice(low_scoring_list)
            self.conversation_history.clear()
        else:
            # Rerank the candidates and collect those assigned label 0
            # into a dictionary of scores.
            reranked_dict = candidates_reranking(
                top_indexes,
                self.conversation_history,
                utterance,
                self.scripts,
                self.reranking_model,
            )
            # If any candidates passed reranking, sort them by score and pick
            # the top one; otherwise keep the initial top-1 retrieved answer.
            if len(reranked_dict) >= 1:
                updated_top_candidates = dict(
                    sorted(reranked_dict.items(), key=lambda item: item[1])
                )
                answer = self.scripts.iloc[list(updated_top_candidates.keys())[0]][
                    "answer"
                ]
            else:
                answer = self.scripts.iloc[top_indexes[0]]["answer"]
        self.conversation_history.append(utterance)
        self.conversation_history.append(answer)
        return answer

# katya = ChatBot()
# katya.load()
# katya.generate_response("hi man!")
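
# A minimal interactive demo sketch (not in the original file), expanding the
# commented-out quick check above. It assumes the pickled artifacts in data/
# are present locally and that the Hugging Face models referenced in load()
# can be downloaded.
if __name__ == "__main__":
    bot = ChatBot()
    bot.load()
    # Simple REPL: type a line to get a reply; empty input exits.
    while True:
        user_input = input("You: ").strip()
        if not user_input:
            break
        print("Sheldon:", bot.generate_response(user_input))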