Spaces:
Runtime error
Runtime error
""" | |
converse.py - functions for handling the conversation between the user and the bot.

Reference for the text-generation arguments used in this module (e.g. no_repeat_ngram_size):
https://huggingface.co/docs/transformers/v4.15.0/en/main_classes/model#transformers.generation_utils.GenerationMixin.generate.no_repeat_ngram_size
""" | |
import pprint as pp | |
import time | |
import torch | |
import transformers | |
from grammar_improve import remove_trailing_punctuation | |
def discussion(
    prompt_text: str,
    speaker: str,
    responder: str,
    pipeline,
    timeout=45,
    max_length=128,
    top_p=0.95,
    top_k=50,
    temperature=0.7,
    full_text=False,
    num_return_sequences=1,
    device=-1,
    verbose=False,
):
    """
    discussion - build a two-party prompt from the user's message, ask the model for a reply, and return the cleaned reply together with the conversation so far.

    The prompt has the layout "<speaker>:\\n<prompt_text>\\n\\n<responder>:\\n", with the
    names and the prompt text lower-cased.

    Parameters
    ----------
    prompt_text : str, the prompt to ask the bot, usually the user's question
    speaker : str, the name of the person who is speaking the prompt
    responder : str, the name of the person who is responding to the prompt
    pipeline : the generation pipeline, forwarded as-is to gen_response
    timeout : int, optional, forwarded to gen_response, defaults to 45
    max_length : int, optional, forwarded to gen_response, defaults to 128
    top_p : float, optional, forwarded to gen_response, defaults to 0.95
    top_k : int, optional, forwarded to gen_response, defaults to 50
    temperature : float, optional, forwarded to gen_response, defaults to 0.7
    full_text : bool, optional, forwarded to gen_response, defaults to False
    num_return_sequences : int, optional, forwarded to gen_response, defaults to 1
    device : int, optional, forwarded to gen_response, defaults to -1
    verbose : bool, optional, whether to print the prompt and the response, defaults to False

    Returns
    -------
    dict with two keys:
        "out_text" : str, the bot response after stripping and trailing-punctuation removal
        "full_conv" : list of str, the prompt pieces followed by the bot response
    """
    # conversation pieces, in order: speaker header, prompt, blank line, responder header
    p_list = [
        speaker.lower() + ":" + "\n",
        prompt_text.lower() + "\n",
        "\n",
        responder.lower() + ":" + "\n",
    ]
    this_prompt = "".join(p_list)

    if verbose:
        print("overall prompt:\n")
        pp.pprint(this_prompt, indent=4)

    # call the model
    print("\n... generating...")
    generation_settings = dict(
        timeout=timeout,
        max_length=max_length,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        full_text=full_text,
        num_return_sequences=num_return_sequences,
        device=device,
        verbose=verbose,
    )
    bot_dialogue = gen_response(
        this_prompt,
        pipeline,
        speaker,
        responder,
        **generation_settings,
    )

    # gen_response may hand back a list of lines or a plain string; flatten to one string
    bot_resp = bot_dialogue
    if isinstance(bot_dialogue, list):
        if len(bot_dialogue) > 1:
            bot_resp = ", ".join(bot_dialogue)
        elif len(bot_dialogue) == 1:
            bot_resp = bot_dialogue[0]
    if isinstance(bot_resp, list):
        # an empty list (or a nested list) ends up here
        bot_resp = " ".join(bot_resp)

    # trim whitespace, then remove the last ',' '.' chars
    bot_resp = remove_trailing_punctuation(bot_resp.strip())

    if verbose:
        print("\n... bot response:\n")
        pp.pprint(bot_resp)

    p_list.extend([bot_resp + "\n", "\n"])
    print("\nfinished!")

    # return the bot response and the full conversation
    return {"out_text": bot_resp, "full_conv": p_list}
def gen_response(
    query: str,
    pipeline,
    speaker: str,
    responder: str,
    timeout=45,
    max_length=128,
    top_p=0.95,
    top_k=50,
    temperature=0.7,
    full_text=False,
    num_return_sequences=1,
    device=-1,
    verbose=False,
    **kwargs,
):
    """
    gen_response - run the pipeline on a prompt and extract the responder's lines from the first generated sequence. This operates underneath the discussion function.

    Parameters
    ----------
    query : str, the full prompt passed to the pipeline
    pipeline : callable pipeline used to generate the response
    speaker : str, the name of the person who is speaking the prompt
    responder : str, the name of the person who is responding to the prompt
    timeout : int, optional, passed to the pipeline as max_time, defaults to 45
    max_length : int, optional, passed to the pipeline as max_length; values above 1024 are reduced to 1024, defaults to 128
    top_p : float, optional, passed to the pipeline as top_p, defaults to 0.95
    top_k : int, optional, passed to the pipeline as top_k, defaults to 50
    temperature : float, optional, passed to the pipeline as temperature, defaults to 0.7
    full_text : bool, optional, passed to the pipeline as return_full_text, defaults to False
    num_return_sequences : int, optional, passed to the pipeline as num_return_sequences, defaults to 1
    device : int, optional, accepted for signature compatibility; not used inside this function
    verbose : bool, optional, whether to print timing and debug output, defaults to False
    **kwargs : extra keyword arguments forwarded to the pipeline call

    Returns
    -------
    the value returned by consolidate_texts for the first generated sequence
    (a list of lines, or a single str when the generated text has only one line)
    """
    if max_length > 1024:
        print("max_length is too large, setting to 1024")
        max_length = 1024

    # fixed decoding settings plus the caller-tunable ones
    generation_args = {
        "max_length": max_length,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "num_return_sequences": num_return_sequences,
        "max_time": timeout,
        "return_full_text": full_text,
        "no_repeat_ngram_size": 3,
        "length_penalty": 0.3,
        "repetition_penalty": 3.4,
        "clean_up_tokenization_spaces": True,
    }

    started = time.perf_counter()
    response = pipeline(query, **generation_args, **kwargs)  # the likely better beam-less method
    elapsed = round(time.perf_counter() - started, 2)

    if verbose:
        print(f"took {elapsed} sec to respond")
        print("\n[DEBUG] generated:\n")
        pp.pprint(response)  # for debugging

    # process the full result to get the ~bot response~ piece
    # TODO: adjust hardcoded value for index to dynamic (if n>1)
    first_text = str(response[0]["generated_text"])
    generated_lines = first_text.split("\n")

    bot_dialogue = consolidate_texts(
        model_resp=generated_lines,
        name_resp=responder,
        name_spk=speaker,
        verbose=verbose,
        print_debug=True,
    )
    if verbose:
        print(f"DEBUG: {bot_dialogue} was original response pre-SC")

    return bot_dialogue
def consolidate_texts(
    model_resp: list,
    name_resp: str = None,
    name_spk: str = None,
    verbose=False,
    print_debug=False,
):
    """
    consolidate_texts - given the model output split into lines, collect the responder's lines up to the point where another speaker's turn begins.

    Lines are scanned in order:
      * a line containing the responder's name (case-insensitive) is treated as a
        name header and skipped;
      * a line containing the speaker's name (case-insensitive), or containing
        ": " / ":\\n" (i.e. it looks like some other person's header), ends the scan;
      * any other line is kept.

    Parameters:
        model_resp (list): the list of strings to consolidate (usually from the model)
        name_resp (str): the name of the person who is responding, defaults to "person beta" when None
        name_spk (str): the name of the person who is speaking, defaults to "person alpha" when None
        verbose (bool): whether to print the input and the consolidated result
        print_debug (bool): whether to print the line that ended the scan

    Returns:
        list of the kept lines; as a special case, when model_resp has exactly
        one element that element itself is returned unchanged (not wrapped in a list)

    Raises:
        AssertionError: if model_resp is empty
    """
    assert len(model_resp) > 0, "model_resp is empty"
    if len(model_resp) == 1:
        # single-line output: nothing to consolidate, hand it back as-is
        return model_resp[0]

    # lower-case the names once so every comparison below is case-insensitive
    responder_key = ("person beta" if name_resp is None else name_resp).lower()
    speaker_key = ("person alpha" if name_spk is None else name_spk).lower()

    if verbose:
        print("====" * 10)
        print(f"\n[DEBUG] initial model_resp has {len(model_resp)} lines: \n\t{model_resp}")
        print(f" the first element is \n\t{model_resp[0]} and it is {type(model_resp[0])}")

    fn_resp = []
    for resline in model_resp:
        line_lc = resline.lower()

        if responder_key in line_lc:
            # responder's own name header (any capitalisation): not part of the reply
            continue

        # the responder's name is known to be absent from here on, so a colon
        # marker means the header belongs to someone else
        mentions_speaker = speaker_key in line_lc
        looks_like_header = ": " in resline or ":\n" in resline
        if mentions_speaker or looks_like_header:
            if print_debug:
                print(f"\nDEBUG: \n\t{resline}\ncaused the break")
            break

        fn_resp.append(resline)

    if verbose:
        print("--" * 10)
        print("\nthe full response is:\n")
        print("\n".join(fn_resp))
        print("--" * 10)

    return fn_resp