Spaces:
Runtime error
Runtime error
import datetime | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
import os | |
import json | |
import speech_recognition as sr | |
import re | |
import time | |
import spacy | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel | |
import pickle | |
import streamlit as st | |
from sklearn.metrics.pairwise import cosine_similarity | |
import run_tts | |
# Build the AI | |
class CelebBot():
    """Conversational bot that answers questions in the persona of a celebrity.

    The most recent utterance lives in ``self.text``: ``speech_to_text``
    stores the recognized user speech there, ``question_answer`` replaces it
    with the bot's reply, and ``text_to_speech`` reads whatever it currently
    holds.
    """

    def __init__(self, name, QA_tokenizer, QA_model, sentTr_tokenizer, sentTr_model, spacy_model, knowledge_sents):
        """Store the persona name, the models and the knowledge sentences.

        Args:
            name: Persona name; used in the wake phrase, in the generation
                prompt and (with spaces replaced by underscores) for TTS.
            QA_tokenizer: Tokenizer used to encode prompts and decode answers.
            QA_model: Model whose ``generate`` method produces the answer.
            sentTr_tokenizer: Tokenizer for the sentence-embedding model.
            sentTr_model: Model whose first output holds token embeddings.
            spacy_model: Passed through unchanged to ``run_tts.tts``.
            knowledge_sents: Sequence of knowledge sentences to retrieve from.
        """
        self.name = name
        print("--- starting up", self.name, "---")
        self.text = ""
        self.QA_tokenizer = QA_tokenizer
        self.QA_model = QA_model
        self.sentTr_tokenizer = sentTr_tokenizer
        self.sentTr_model = sentTr_model
        self.spacy_model = spacy_model
        self.all_knowledge = knowledge_sents

    def get_seq2seq_model(self, _model_id):
        """Load a pretrained sequence-to-sequence model by id."""
        return AutoModelForSeq2SeqLM.from_pretrained(_model_id)

    def get_model(self, _model_id):
        """Load a pretrained base model by id."""
        return AutoModel.from_pretrained(_model_id)

    def get_tokenizer(self, _model_id):
        """Load a pretrained tokenizer by id."""
        return AutoTokenizer.from_pretrained(_model_id)

    def speech_to_text(self):
        """Record one utterance from the microphone into ``self.text``.

        On recognition failure ``self.text`` is set to the empty string and a
        notice is printed; the failure is not propagated.
        """
        recognizer = sr.Recognizer()
        with sr.Microphone() as mic:
            recognizer.adjust_for_ambient_noise(mic, duration=1)
            time.sleep(1)
            print("listening")
            audio = recognizer.listen(mic)
            try:
                self.text = recognizer.recognize_google(audio)
            except (sr.UnknownValueError, sr.RequestError):
                # Only recognition failures are swallowed: unintelligible
                # speech or an unreachable/failed API request. Anything else
                # (KeyboardInterrupt, programming errors) must surface.
                self.text = ""
                print("me --> No audio recognized")

    def wake_up(self, text):
        """Return True when *text* contains the wake phrase ``"hey <name>"``.

        Both the text and the name are lower-cased before comparing, so a
        name written with capitals still matches.
        """
        return "hey " + self.name.lower() in text.lower()

    def text_to_speech(self, autoplay=True):
        """Synthesize ``self.text`` through ``run_tts.tts`` and return its result."""
        voice_id = "_".join(self.name.split(" "))
        return run_tts.tts(self.text, voice_id, self.spacy_model, autoplay)

    def sentence_embeds_inference(self, texts: list):
        """Return L2-normalized, attention-masked mean-pooled embeddings for *texts*."""

        def _mean_pooling(model_output, attention_mask):
            # First element of model_output contains all token embeddings.
            token_embeddings = model_output[0]
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            # Sum only the unmasked positions; the clamp guards against a
            # division by zero when a row's mask is entirely zero.
            summed = torch.sum(token_embeddings * input_mask_expanded, 1)
            counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
            return summed / counts

        # Tokenize sentences
        encoded_input = self.sentTr_tokenizer(texts, padding=True, truncation=True, return_tensors='pt')

        # Compute token embeddings
        with torch.no_grad():
            model_output = self.sentTr_model(**encoded_input)

        # Perform pooling, then normalize
        sentence_embeddings = _mean_pooling(model_output, encoded_input['attention_mask'])
        return F.normalize(sentence_embeddings, p=2, dim=1)

    def retrieve_knowledge_assertions(self):
        """Return the knowledge sentences most similar to the current question.

        The question is ``"<name>, <self.text>"``. Up to 8 sentences with the
        highest cosine similarity are selected and joined with single spaces,
        kept in the order they appear in the knowledge base. Returns an empty
        string when the knowledge base is empty.
        """
        # len() rather than truthiness: the knowledge base may be a numpy array.
        if len(self.all_knowledge) == 0:
            return ''

        question_embeddings = self.sentence_embeds_inference([self.name + ', ' + self.text])
        all_knowledge_embeddings = self.sentence_embeds_inference(self.all_knowledge)
        similarity = cosine_similarity(all_knowledge_embeddings.cpu(), question_embeddings.cpu())
        similarity = np.reshape(similarity, (1, -1))[0]

        K = min(8, len(self.all_knowledge))
        # argpartition finds the K best indices without a full sort; np.sort
        # then restores knowledge-base order for the selected indices.
        top_K = np.sort(np.argpartition(similarity, -K)[-K:])
        all_knowledge_assertions = np.array(self.all_knowledge)[top_K]
        return ' '.join(all_knowledge_assertions)

    def question_answer(self, instruction1='', knowledge=''):
        """Turn the utterance in ``self.text`` into the bot's reply.

        An empty utterance is returned unchanged. An utterance containing the
        wake phrase yields a fixed greeting. Anything else is answered by the
        QA model from a prompt built out of the persona instruction, the
        retrieved knowledge and the question.

        Args:
            instruction1: Accepted for backward compatibility; it is
                overwritten before the prompt is built.
            knowledge: Accepted for backward compatibility; it is overwritten
                with the retrieved knowledge before the prompt is built.

        Returns:
            The reply, which is also stored in ``self.text``.
        """
        if self.text == "":
            return self.text

        # Wake up
        if self.wake_up(self.text):
            self.text = f"Hello I am {self.name} the AI, what can I do for you?"
            return self.text

        # Have a conversation
        instruction1 = f'[Instruction] You are a celebrity named {self.name}. You need to answer the question based on knowledge and commonsense.'
        knowledge = self.retrieve_knowledge_assertions()
        query = f"{instruction1} [knowledge] {knowledge} [question] {self.text} {self.name}!"
        input_ids = self.QA_tokenizer(query, return_tensors="pt").input_ids
        outputs = self.QA_model.generate(input_ids, max_length=1024)
        self.text = self.QA_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self.text
def action_time():
    """Return the current time as a phrase of the form ``"it's HH:MM"`` (24-hour clock)."""
    now = datetime.datetime.now()
    clock = now.strftime("%H:%M")
    return "it's " + clock
def save_kb(kb, filename):
    """Pickle *kb* to the file at *filename*, overwriting any existing file."""
    with open(filename, mode="wb") as sink:
        pickle.dump(kb, sink)
def load_kb(filename):
    """Read and return the object pickled in the file at *filename*.

    NOTE: unpickling can execute arbitrary code, so only load files from a
    trusted source.
    """
    with open(filename, mode="rb") as source:
        return pickle.load(source)