Peter committed on
Commit
235585a
1 Parent(s): 9a9e8f9

📝 docstrings and upd inputs

Browse files

Signed-off-by: Peter <74869040+pszemraj@users.noreply.github.com>

Files changed (1) hide show
  1. app.py +22 -22
app.py CHANGED
@@ -45,15 +45,15 @@ cwd = Path.cwd()
45
  my_cwd = str(cwd.resolve()) # string so it can be passed to os.path() objects
46
 
47
 
48
- def chat(prompt_message, temperature=0.7, top_p=0.95, top_k=50):
49
  """
50
- chat - helper function that makes the whole gradio thing work.
51
 
52
- Args:
53
- trivia_query (str): the question to ask the bot
54
-
55
- Returns:
56
- [str]: the bot's response
57
  """
58
  history = []
59
  response = ask_gpt(
@@ -78,26 +78,26 @@ def ask_gpt(
78
  chat_pipe,
79
  speaker="person alpha",
80
  responder="person beta",
81
- max_len=96,
 
82
  top_p=0.95,
83
  top_k=25,
84
  temperature=0.6,
85
- ):
86
  """
87
-
88
- ask_gpt - a function that takes in a prompt and generates a response using the pipeline. This interacts the discussion function.
89
-
90
- Parameters:
91
- message (str): the question to ask the bot
92
- chat_pipe (str): the chat_pipe to use for the bot (default: "pszemraj/Ballpark-Trivia-XL")
93
- speaker (str): the name of the speaker (default: "person alpha")
94
- responder (str): the name of the responder (default: "person beta")
95
- max_len (int): the maximum length of the response (default: 128)
96
- top_p (float): the top probability threshold (default: 0.95)
97
- top_k (int): the top k threshold (default: 50)
98
- temperature (float): the temperature of the response (default: 0.7)
99
  """
100
-
101
  st = time.perf_counter()
102
  prompt = clean(message) # clean user input
103
  prompt = prompt.strip() # get rid of any extra whitespace
 
45
  my_cwd = str(cwd.resolve()) # string so it can be passed to os.path() objects
46
 
47
 
48
+ def chat(prompt_message, temperature: float = 0.6, top_p: float = 0.95, top_k: int = 25)-> str:
49
  """
50
+ chat - the main function for the chatbot. This is the function that is called when the user
51
 
52
+ :param _type_ prompt_message: the message to send to the model
53
+ :param float temperature: the temperature value for the model, defaults to 0.6
54
+ :param float top_p: the top_p value for the model, defaults to 0.95
55
+ :param int top_k: the top_k value for the model, defaults to 25
56
+ :return str: the response from the model
57
  """
58
  history = []
59
  response = ask_gpt(
 
78
  chat_pipe,
79
  speaker="person alpha",
80
  responder="person beta",
81
+ min_length=4,
82
+ max_length=64,
83
  top_p=0.95,
84
  top_k=25,
85
  temperature=0.6,
86
+ ) -> str:
87
  """
88
+ ask_gpt - helper function that asks the GPT model a question and returns the response
89
+
90
+ :param str message: the question to ask the model
91
+ :param chat_pipe: the pipeline object for the model, created by the pipeline() function
92
+ :param str speaker: the name of the speaker, defaults to "person alpha"
93
+ :param str responder: the name of the responder, defaults to "person beta"
94
+ :param int min_length: the minimum length of the response, defaults to 4
95
+ :param int max_length: the maximum length of the response, defaults to 64
96
+ :param float top_p: the top_p value for the model, defaults to 0.95
97
+ :param int top_k: the top_k value for the model, defaults to 25
98
+ :param float temperature: the temperature value for the model, defaults to 0.6
99
+ :return str: the response from the model
100
  """
 
101
  st = time.perf_counter()
102
  prompt = clean(message) # clean user input
103
  prompt = prompt.strip() # get rid of any extra whitespace