Mimi committed on
Commit
4ad9750
·
1 Parent(s): 39a8462

Added temporary decoding kwargs

Browse files
Files changed (1) hide show
  1. agent.py +19 -1
agent.py CHANGED
@@ -107,6 +107,18 @@ new_chat_template = """{{- bos_token }}
107
  {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
108
  {%- endif %}"""
109
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  class Naomi:
111
  def __init__(self, **kwargs):
112
  # init dataclasses
@@ -129,11 +141,17 @@ class Naomi:
129
 
130
  def respond(self, user_input: dict, **kwargs):
131
  """ Called during stream. """
 
 
 
 
 
 
132
  # user msg handling
133
  format_user_input = self.model.tokenizer_.hf_tokenizer.apply_chat_template([user_input], tokenize=False, add_generation_prompt=False)
134
  self.chat_history += format_user_input
135
  # agent msg results + clean
136
- response = self.model(self.chat_history, **kwargs)
137
  output = "".join(response['choices'][0]['text'].split('\n\n')[1:])
138
  # update history
139
  self.chat_history += self.model.tokenizer_.hf_tokenizer.apply_chat_template([{'role': 'assistant', 'content': output}], tokenize=False, add_generation_prompt=False)
 
107
  {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
108
  {%- endif %}"""
109
 
110
# Default sampling parameters handed to the model on every generation call.
# NOTE(review): temperature=1.5 combined with top_p=0.2 is an unusual pairing —
# the tight nucleus mostly cancels the high temperature; confirm it is intentional.
DECODE_ARGS = {
    'max_tokens': 300,          # generation budget for a single turn
    'temperature': 1.5,
    'top_p': 0.2,
    'frequency_penalty': 0.3,
    'presence_penalty': 0.5,
    'seed': 42,                 # fixed seed for reproducible sampling
    'mirostat_tau': 0.3,
    'mirostat_eta': 0.0001,
}

# Amount the max_tokens budget is bumped on each successive respond() call.
MAX_TOKENS_INCREMENT = 50

122
  class Naomi:
123
  def __init__(self, **kwargs):
124
  # init dataclasses
 
141
 
142
  def respond(self, user_input: dict, **kwargs):
143
  """ Called during stream. """
144
+ max_tokens = DECODE_ARGS['max_tokens']
145
+ DECODE_ARGS['max_tokens'] = max_tokens + MAX_TOKENS_INCREMENT
146
+
147
+ if kwargs:
148
+ DECODE_ARGS.update(kwargs)
149
+
150
  # user msg handling
151
  format_user_input = self.model.tokenizer_.hf_tokenizer.apply_chat_template([user_input], tokenize=False, add_generation_prompt=False)
152
  self.chat_history += format_user_input
153
  # agent msg results + clean
154
+ response = self.model(self.chat_history, **DECODE_ARGS)
155
  output = "".join(response['choices'][0]['text'].split('\n\n')[1:])
156
  # update history
157
  self.chat_history += self.model.tokenizer_.hf_tokenizer.apply_chat_template([{'role': 'assistant', 'content': output}], tokenize=False, add_generation_prompt=False)