lhzstar committed on
Commit
8bc0535
1 Parent(s): 7fedcd4

new commits

Browse files
Files changed (2) hide show
  1. app.py +9 -2
  2. celebbot.py +36 -11
app.py CHANGED
@@ -43,7 +43,10 @@ def main():
43
  def example_submit(text):
44
  st.session_state["prompt_from_text"] = text
45
 
46
- st.session_state["celeb_name"] = st.sidebar.selectbox('Choose your celebrity crush', options=list(celeb_data.keys()))
 
 
 
47
  model_id=st.sidebar.selectbox("Choose Your model",options=model_list)
48
 
49
  st.session_state["QA_model_path"] = f"google/{model_id}" if "flan-t5" in model_id else model_id
@@ -102,7 +105,11 @@ def main():
102
  st.session_state["messages"].append({"role": "user", "content": prompt})
103
 
104
  # Add assistant response to chat history
105
- response = st.session_state["celeb_bot"].question_answer()
 
 
 
 
106
 
107
  # disable autoplay to play in HTML
108
  b64 = st.session_state["celeb_bot"].text_to_speech(autoplay=False)
 
43
  def example_submit(text):
44
  st.session_state["prompt_from_text"] = text
45
 
46
+ def clear_chat_his():
47
+ st.session_state["messages"] = []
48
+
49
+ st.sidebar.selectbox('Choose your celebrity crush', key="celeb_name", options=sorted(list(celeb_data.keys())), on_change=clear_chat_his)
50
  model_id=st.sidebar.selectbox("Choose Your model",options=model_list)
51
 
52
  st.session_state["QA_model_path"] = f"google/{model_id}" if "flan-t5" in model_id else model_id
 
105
  st.session_state["messages"].append({"role": "user", "content": prompt})
106
 
107
  # Add assistant response to chat history
108
+ if len(st.session_state["messages"]) < 3:
109
+ response = st.session_state["celeb_bot"].question_answer()
110
+ else:
111
+ chat_his = "Question: {q}\n\nAnswer: {a}\n\n".format(q=st.session_state["messages"][-3]["content"], a=st.session_state["messages"][-2]["content"])
112
+ response = st.session_state["celeb_bot"].question_answer(chat_his=chat_his)
113
 
114
  # disable autoplay to play in HTML
115
  b64 = st.session_state["celeb_bot"].text_to_speech(autoplay=False)
celebbot.py CHANGED
@@ -97,7 +97,25 @@ class CelebBot():
97
 
98
  def third_to_first_person(self, text):
99
  text = text.replace(" ", " ")
100
- name = self.name.split(" ")[-1].lower()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  doc = self.spacy_model(text)
102
  transformed_text = []
103
 
@@ -109,10 +127,14 @@ class CelebBot():
109
  transformed_text.append("me")
110
  elif token.text.lower() == "his":
111
  transformed_text.append("my")
112
- elif token.text.lower() == name and token.dep_ in ["nsubj", "nsubjpass"]:
 
 
113
  transformed_text.append("I")
114
- elif token.text in ["'s", "’s"] and doc[i-1].text.lower() == name:
115
  transformed_text[-1] = "my"
 
 
116
  elif token.text.lower() == "their":
117
  transformed_text.append("our")
118
  elif token.text.lower() == "they":
@@ -127,10 +149,14 @@ class CelebBot():
127
  transformed_text.append("my")
128
  else:
129
  transformed_text.append("me")
130
- elif token.text.lower() == name and token.dep_ in ["nsubj", "nsubjpass"]:
 
 
131
  transformed_text.append("I")
132
- elif token.text in ["'s", "’s"] and doc[i-1].text.lower() == name:
133
  transformed_text[-1] = "my"
 
 
134
  elif token.text.lower() == "their":
135
  transformed_text.append("our")
136
  elif token.text.lower() == "they":
@@ -140,19 +166,18 @@ class CelebBot():
140
 
141
  return "".join(transformed_text)
142
 
143
- def question_answer(self, instruction='', knowledge=''):
144
  instruction = f"Your name is {self.name}. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
145
  if self.text != "":
146
- if re.search(re.compile(rf'\b(you|your|yours)\b', flags=re.IGNORECASE), self.text) != None:
147
- knowledge = self.retrieve_knowledge_assertions()
148
- else:
149
  knowledge = self.retrieve_knowledge_assertions(change_person=False)
 
 
150
 
151
- query = f"Context: {instruction} {knowledge}\n\nQuestion: {self.text}\n\nAnswer:"
152
  input_ids = self.QA_tokenizer(f"{query}", return_tensors="pt").input_ids.to(self.QA_model.device)
153
  outputs = self.QA_model.generate(input_ids, max_length=1024, min_length=8, do_sample=True, temperature=0.2, repetition_penalty=2.5)
154
  self.text = self.QA_tokenizer.decode(outputs[0], skip_special_tokens=True)
155
- # self.text = " ".join([i.text.strip().capitalize() for i in self.spacy_model(self.text).sents])
156
  return self.text
157
 
158
  @staticmethod
 
97
 
98
  def third_to_first_person(self, text):
99
  text = text.replace(" ", " ")
100
+ possible_names = [name.lower() for name in self.name.split(" ")]
101
+ if "bundchen" in self.name.lower():
102
+ possible_names.append("bündchen")
103
+ if "beyonce" in self.name.lower():
104
+ possible_names.append("beyoncé")
105
+ if "adele" in self.name.lower():
106
+ possible_names.append("adkins")
107
+ if "katy perry" in self.name.lower():
108
+ possible_names.append("hudson")
109
+ if "lady gaga" in self.name.lower():
110
+ possible_names.append("germanotta")
111
+ if "michelle obama" in self.name.lower():
112
+ possible_names.append("robinson")
113
+ if "natalie portman" in self.name.lower():
114
+ possible_names.append("hershlag")
115
+ if "rihanna" in self.name.lower():
116
+ possible_names.append("fenty")
117
+ if "victoria beckham" in self.name.lower():
118
+ possible_names.append("adams")
119
  doc = self.spacy_model(text)
120
  transformed_text = []
121
 
 
127
  transformed_text.append("me")
128
  elif token.text.lower() == "his":
129
  transformed_text.append("my")
130
+ elif token.text.lower() == "himself":
131
+ transformed_text.append("myself")
132
+ elif token.text.lower() in possible_names and token.dep_ in ["nsubj", "nsubjpass"]:
133
  transformed_text.append("I")
134
+ elif token.text in ["'s", "’s"] and doc[i-1].text.lower() in possible_names:
135
  transformed_text[-1] = "my"
136
+ elif token.text.lower() in possible_names and token.dep_ in ["dobj", "dative"]:
137
+ transformed_text.append("me")
138
  elif token.text.lower() == "their":
139
  transformed_text.append("our")
140
  elif token.text.lower() == "they":
 
149
  transformed_text.append("my")
150
  else:
151
  transformed_text.append("me")
152
+ elif token.text.lower() == "herself":
153
+ transformed_text.append("myself")
154
+ elif token.text.lower() in possible_names and token.dep_ in ["nsubj", "nsubjpass"]:
155
  transformed_text.append("I")
156
+ elif token.text in ["'s", "’s"] and doc[i-1].text.lower() in possible_names:
157
  transformed_text[-1] = "my"
158
+ elif token.text.lower() in possible_names and token.dep_ in ["dobj", "dative"]:
159
+ transformed_text.append("me")
160
  elif token.text.lower() == "their":
161
  transformed_text.append("our")
162
  elif token.text.lower() == "they":
 
166
 
167
  return "".join(transformed_text)
168
 
169
+ def question_answer(self, instruction='', knowledge='', chat_his=''):
170
  instruction = f"Your name is {self.name}. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
171
  if self.text != "":
172
+ if re.search(re.compile(rf'\b({self.name})\b', flags=re.IGNORECASE), self.text) != None:
 
 
173
  knowledge = self.retrieve_knowledge_assertions(change_person=False)
174
+ else:
175
+ knowledge = self.retrieve_knowledge_assertions()
176
 
177
+ query = f"Context: {instruction} {knowledge}\n\nChat History: {chat_his}Question: {self.text}\n\nAnswer:"
178
  input_ids = self.QA_tokenizer(f"{query}", return_tensors="pt").input_ids.to(self.QA_model.device)
179
  outputs = self.QA_model.generate(input_ids, max_length=1024, min_length=8, do_sample=True, temperature=0.2, repetition_penalty=2.5)
180
  self.text = self.QA_tokenizer.decode(outputs[0], skip_special_tokens=True)
 
181
  return self.text
182
 
183
  @staticmethod