pseudotensor committed
Commit: b38cab2
Parent(s): 8a46296

Update with h2oGPT hash c0762b9528f67797cf2d2ec3a99ae7880d324fec

Files changed (2):
  1. app.py +15 -12
  2. utils.py +11 -13
app.py CHANGED
@@ -2,10 +2,8 @@ import functools
 import inspect
 import sys
 import os
-import time
 import traceback
 import typing
-import filelock
 from utils import set_seed, flatten_list, clear_torch_cache, system_info_print, zip_data, save_generate_output
 
 SEED = 1236
@@ -60,7 +58,7 @@ def main(
 
 llama_type: bool = None,
 debug: bool = False,
-save_path: str = None,
+save_dir: str = None,
 share: bool = True,
 local_files_only: bool = False,
 resume_download: bool = True,
@@ -114,7 +112,7 @@ def main(
 if is_hf:
     # must override share if in spaces
     share = False
-    save_path = os.getenv('SAVE_PATH')
+    save_dir = os.getenv('SAVE_DIR', save_dir)
 
 # get defaults
 model_lower = base_model.lower()
@@ -182,7 +180,7 @@ def main(
 if not eval_sharegpt_as_output:
     model, tokenizer, device = get_model(**locals())
     model_state = [model, tokenizer, device, base_model]
-    fun = partial(evaluate, model_state, debug=debug, chat=chat, save_path=save_path)
+    fun = partial(evaluate, model_state, debug=debug, chat=chat, save_dir=save_dir)
 else:
     assert eval_sharegpt_prompts_only > 0
 
@@ -816,7 +814,7 @@ body.dark{background:linear-gradient(#0d0d0d,#333333);}"""
 file_output = gr.File()
 
 # Get flagged data
-zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_path']])
+zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']])
 zip_btn.click(zip_data1, inputs=None, outputs=file_output)
 
 def check_admin_pass(x):
@@ -1143,7 +1141,7 @@ body.dark{background:linear-gradient(#0d0d0d,#333333);}"""
 
 
 input_args_list = ['model_state']
-inputs_kwargs_list = ['debug', 'chat', 'save_path', 'hard_stop_list', 'sanitize_bot_response', 'model_state0']
+inputs_kwargs_list = ['debug', 'chat', 'save_dir', 'hard_stop_list', 'sanitize_bot_response', 'model_state0']
 
 
 def get_inputs_list(inputs_dict, model_lower):
@@ -1206,7 +1204,7 @@ def evaluate(
 src_lang=None,
 tgt_lang=None,
 debug=False,
-save_path=None,
+save_dir=None,
 chat=False,
 hard_stop_list=None,
 sanitize_bot_response=True,
@@ -1269,7 +1267,7 @@ def evaluate(
 # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
 # stopping only starts once output is beyond prompt
 # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
-stop_words = [human, bot]
+stop_words = [human, bot, '\n' + human, '\n' + bot]
 encounters = [1, 2]
 elif prompt_type == 'instruct_vicuna':
 # even below is not enough, generic strings and many ways to encode
@@ -1300,6 +1298,9 @@ def evaluate(
 # avoid padding in front of tokens
 if tokenizer.pad_token:
     stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
+# handle fake \n added
+stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
+# build stopper
 stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters)])
 else:
     stopping_criteria = StoppingCriteriaList()
@@ -1420,14 +1421,16 @@ def evaluate(
 raise StopIteration
 yield prompter.get_response(decoded_output, prompt=inputs_decoded,
                             sanitize_bot_response=sanitize_bot_response)
-if save_path and decoded_output:
-    save_generate_output(output=decoded_output, base_model=base_model, json_file_path=save_path)
-    return
+if save_dir and decoded_output:
+    save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
 else:
     outputs = model.generate(**gen_kwargs)
     outputs = [decoder(s) for s in outputs.sequences]
     yield prompter.get_response(outputs, prompt=inputs_decoded,
                                 sanitize_bot_response=sanitize_bot_response)
+    if save_dir and outputs and len(outputs) >= 1:
+        decoded_output = prompt + outputs[0]
+        save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
 
 
 def get_generate_params(model_lower, chat,
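
Note on the stop-word hunks: the '\n'-prefixed variants are the usual trick for SentencePiece-style tokenizers, where encoding "<human>:" at the start of a string yields a sentence-initial token sequence that never matches the same text emitted mid-stream; prefixing '\n' forces the mid-text encoding, and the added "handle fake \n added" comprehension then drops the leading newline id. Below is a minimal sketch of a criterion compatible with the stops/encounters interface used above. It is an illustration, not the h2oGPT StoppingCriteriaSub itself, and cycling a short encounters list (e.g. [1, 2] over four stop words) is an assumption here.

import torch
from transformers import StoppingCriteria

class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once each watched token sequence has appeared enough times."""

    def __init__(self, stops=None, encounters=None):
        super().__init__()
        self.stops = stops or []             # list of 1-D LongTensors of token ids
        self.encounters = encounters or [1]  # required hits; cycled if shorter than stops (assumption)
        self.counts = [0] * len(self.stops)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        seq = input_ids[0]
        for i, stop in enumerate(self.stops):
            stop = stop.to(seq.device)
            # one token is generated per call, so a tail match means one new occurrence
            if len(seq) >= len(stop) and torch.equal(seq[-len(stop):], stop):
                self.counts[i] += 1
                if self.counts[i] >= self.encounters[i % len(self.encounters)]:
                    return True
        return False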
utils.py CHANGED
@@ -118,33 +118,31 @@ def _zip_data(root_dirs=None, zip_path='data.zip', base_dir='./'):
 return "data.zip"
 
 
-def save_generate_output(output=None, base_model=None, json_file_path=None):
+def save_generate_output(output=None, base_model=None, save_dir=None):
 try:
-    return _save_generate_output(output=output, base_model=base_model, json_file_path=json_file_path)
+    return _save_generate_output(output=output, base_model=base_model, save_dir=save_dir)
 except Exception as e:
     traceback.print_exc()
     print('Exception in saving: %s' % str(e))
 
 
-def _save_generate_output(output=None, base_model=None, json_file_path=None):
+def _save_generate_output(output=None, base_model=None, save_dir=None):
 """
-Save conversation to .json, row by row
+Save conversation to .json, row by row.
+save_dir is the directory that will hold the JSON file; it is created if missing.
 Appends if file exists
 """
-assert isinstance(json_file_path, str), "must provide save_path"
-as_file = os.path.normpath(json_file_path)
-if os.path.isfile(as_file):
-    # protection if had file there before
-    os.remove(as_file)
-os.makedirs(json_file_path, exist_ok=True)
-json_file_file = os.path.join(json_file_path, 'save.json')
+assert save_dir, "save_dir must be provided"
+if os.path.exists(save_dir) and not os.path.isdir(save_dir):
+    raise RuntimeError("save_dir already exists and is not a directory!")
+os.makedirs(save_dir, exist_ok=True)
 import json
 if output[-10:] == '\n\n<human>:':
     # remove trailing <human>:
     output = output[:-10]
-with filelock.FileLock("save_path.lock"):
+with filelock.FileLock("save_dir.lock"):
     # lock logging in case have concurrency
-    with open(json_file_file, "a") as f:
+    with open(os.path.join(save_dir, "history.json"), "a") as f:
         # just add [ at start, and ] at end, and have proper JSON dataset
         f.write(
             " " + json.dumps(