pseudotensor committed · b38cab2
Parent(s): 8a46296
Update with h2oGPT hash c0762b9528f67797cf2d2ec3a99ae7880d324fec
app.py
CHANGED
@@ -2,10 +2,8 @@ import functools
 import inspect
 import sys
 import os
-import time
 import traceback
 import typing
-import filelock
 from utils import set_seed, flatten_list, clear_torch_cache, system_info_print, zip_data, save_generate_output
 
 SEED = 1236
@@ -60,7 +58,7 @@ def main(
 
         llama_type: bool = None,
         debug: bool = False,
-        …
+        save_dir: str = None,
         share: bool = True,
         local_files_only: bool = False,
         resume_download: bool = True,
@@ -114,7 +112,7 @@ def main(
     if is_hf:
        # must override share if in spaces
        share = False
-    …
+    save_dir = os.getenv('SAVE_DIR', save_dir)
 
     # get defaults
     model_lower = base_model.lower()
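Note on the SAVE_DIR override: os.getenv returns its second argument when the variable is unset, so a value set in the Space's environment takes precedence over the main() argument. A minimal standalone sketch of the same pattern (not h2oGPT code):

import os

def resolve_save_dir(save_dir=None):
    # Environment wins when set; otherwise keep the value passed in code.
    return os.getenv('SAVE_DIR', save_dir)

os.environ.pop('SAVE_DIR', None)
print(resolve_save_dir('saves'))       # -> saves
os.environ['SAVE_DIR'] = '/data/out'
print(resolve_save_dir('saves'))       # -> /data/out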
@@ -182,7 +180,7 @@ def main(
         if not eval_sharegpt_as_output:
             model, tokenizer, device = get_model(**locals())
             model_state = [model, tokenizer, device, base_model]
-            fun = partial(evaluate, model_state, debug=debug, chat=chat, …
+            fun = partial(evaluate, model_state, debug=debug, chat=chat, save_dir=save_dir)
         else:
             assert eval_sharegpt_prompts_only > 0
 
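Note on the partial: functools.partial pre-binds model_state and the fixed flags so the resulting callable only takes per-request inputs, which is the shape a UI callback expects. An illustrative sketch with hypothetical names (only evaluate, model_state, and save_dir come from the diff):

import functools

def evaluate(model_state, prompt, debug=False, chat=False, save_dir=None):
    # model_state and the flags are bound once; prompt varies per call.
    return 'chat=%s save_dir=%s prompt=%s' % (chat, save_dir, prompt)

model_state = ['model', 'tokenizer', 'cpu', 'base_model']
fun = functools.partial(evaluate, model_state, debug=False, chat=True, save_dir='saves')
print(fun('hello'))  # the pre-bound state travels with the callable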
@@ -816,7 +814,7 @@ body.dark{background:linear-gradient(#0d0d0d,#333333);}"""
                 file_output = gr.File()
 
             # Get flagged data
-            zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['…
+            zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']])
             zip_btn.click(zip_data1, inputs=None, outputs=file_output)
 
             def check_admin_pass(x):
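The utils.py hunk header below shows the signature _zip_data(root_dirs=None, zip_path='data.zip', base_dir='./') returning "data.zip"; binding root_dirs up front lets zip_btn.click run it with inputs=None. A rough, hypothetical sketch of such a zipper (the real body is not shown in this diff):

import os
import zipfile

def _zip_data(root_dirs=None, zip_path='data.zip', base_dir='./'):
    # Add every file under each existing root dir; skip None or missing
    # dirs (e.g. when save_dir was never configured).
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
        for root_dir in root_dirs or []:
            if root_dir is None or not os.path.isdir(root_dir):
                continue
            for dirpath, _, filenames in os.walk(root_dir):
                for filename in filenames:
                    path = os.path.join(dirpath, filename)
                    zf.write(path, arcname=os.path.relpath(path, base_dir))
    return "data.zip"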
@@ -1143,7 +1141,7 @@ body.dark{background:linear-gradient(#0d0d0d,#333333);}"""
 
 
 input_args_list = ['model_state']
-inputs_kwargs_list = ['debug', 'chat', '…
+inputs_kwargs_list = ['debug', 'chat', 'save_dir', 'hard_stop_list', 'sanitize_bot_response', 'model_state0']
 
 
 def get_inputs_list(inputs_dict, model_lower):
@@ -1206,7 +1204,7 @@ def evaluate(
         src_lang=None,
         tgt_lang=None,
         debug=False,
-        …
+        save_dir=None,
         chat=False,
         hard_stop_list=None,
         sanitize_bot_response=True,
@@ -1269,7 +1267,7 @@ def evaluate(
         # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
         # stopping only starts once output is beyond prompt
         # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
-        stop_words = [human, bot]
+        stop_words = [human, bot, '\n' + human, '\n' + bot]
         encounters = [1, 2]
     elif prompt_type == 'instruct_vicuna':
         # even below is not enough, generic strings and many ways to encode
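Why the '\n'-prefixed variants are added: SentencePiece-style tokenizers encode a string differently depending on its left context, so <human>: preceded by a newline can map to different token ids than <human>: on its own. A quick probe of the effect (assumes the transformers package; the checkpoint is just an example and the exact ids vary by model):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('huggyllama/llama-7b')
human = '<human>:'
for s in (human, '\n' + human):
    # add_special_tokens=False keeps BOS/EOS out of the comparison
    print(repr(s), tokenizer(s, add_special_tokens=False)['input_ids'])
# The two id sequences typically differ, so both forms must be
# registered as stop words for stopping to trigger reliably.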
@@ -1300,6 +1298,9 @@ def evaluate(
         # avoid padding in front of tokens
         if tokenizer.pad_token:
             stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
+        # handle fake \n added
+        stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
+        # build stopper
         stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters)])
     else:
         stopping_criteria = StoppingCriteriaList()
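StoppingCriteriaSub is defined elsewhere in app.py and is not part of this diff; a plausible minimal version matching the stops=/encounters= arguments used above might look like the following (a sketch, not the commit's exact class):

import torch
from transformers import StoppingCriteria

class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=None, encounters=None):
        super().__init__()
        self.stops = stops or []                      # 1-D tensors of token ids
        self.encounters = encounters or [1] * len(self.stops)

    def __call__(self, input_ids, scores, **kwargs):
        # Halt once stop sequence i occurs at least encounters[i] times.
        seq = input_ids[0]
        for stop, needed in zip(self.stops, self.encounters):
            stop = stop.to(seq.device)
            n = stop.shape[0]
            hits = sum(torch.equal(stop, seq[i:i + n])
                       for i in range(seq.shape[0] - n + 1))
            if hits >= needed:
                return True
        return False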
@@ -1420,14 +1421,16 @@ def evaluate(
                     raise StopIteration
                 yield prompter.get_response(decoded_output, prompt=inputs_decoded,
                                             sanitize_bot_response=sanitize_bot_response)
-            if …
-                save_generate_output(output=decoded_output, base_model=base_model, …
-                return
+            if save_dir and decoded_output:
+                save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
         else:
             outputs = model.generate(**gen_kwargs)
             outputs = [decoder(s) for s in outputs.sequences]
             yield prompter.get_response(outputs, prompt=inputs_decoded,
                                         sanitize_bot_response=sanitize_bot_response)
+            if save_dir and outputs and len(outputs) >= 1:
+                decoded_output = prompt + outputs[0]
+                save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
 
 
 def get_generate_params(model_lower, chat,
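Both branches now funnel the final text into save_generate_output, guarded by save_dir: the streaming branch saves the last decoded_output after the generator loop, while the batch branch first rebuilds decoded_output from prompt + outputs[0]. A toy stand-in for the streaming shape (save_generate_output is the real utils function; everything else here is illustrative):

from utils import save_generate_output

def generate_stream(prompt, save_dir=None):
    # Yield growing partial outputs, then persist once at the end,
    # mirroring the "save after the stream finishes" pattern above.
    text = ''
    for token in ['Hello', ',', ' world']:
        text += token
        yield text
    if save_dir and text:
        save_generate_output(output=prompt + text, base_model='demo', save_dir=save_dir)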
utils.py
CHANGED
@@ -118,33 +118,31 @@ def _zip_data(root_dirs=None, zip_path='data.zip', base_dir='./'):
     return "data.zip"
 
 
-def save_generate_output(output=None, base_model=None, …
+def save_generate_output(output=None, base_model=None, save_dir=None):
     try:
-        return _save_generate_output(output=output, base_model=base_model, …
+        return _save_generate_output(output=output, base_model=base_model, save_dir=save_dir)
     except Exception as e:
         traceback.print_exc()
         print('Exception in saving: %s' % str(e))
 
 
-def _save_generate_output(output=None, base_model=None, …
+def _save_generate_output(output=None, base_model=None, save_dir=None):
     """
-    Save conversation to .json, row by row
+    Save conversation to .json, row by row.
+    json_file_path is path to final JSON file. If not in ., then will attempt to make directories.
     Appends if file exists
     """
-    assert …
-    …
-    …
-    …
-    os.remove(as_file)
-    os.makedirs(json_file_path, exist_ok=True)
-    json_file_file = os.path.join(json_file_path, 'save.json')
+    assert save_dir, "save_dir must be provided"
+    if os.path.exists(save_dir) and not os.path.isdir(save_dir):
+        raise RuntimeError("save_dir already exists and is not a directory!")
+    os.makedirs(save_dir, exist_ok=True)
     import json
     if output[-10:] == '\n\n<human>:':
         # remove trailing <human>:
         output = output[:-10]
-    with filelock.FileLock("…
+    with filelock.FileLock("save_dir.lock"):
         # lock logging in case have concurrency
-        with open(…
+        with open(os.path.join(save_dir, "history.json"), "a") as f:
             # just add [ at start, and ] at end, and have proper JSON dataset
             f.write(
                 " " + json.dumps(
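Because the writer appends json.dumps rows instead of maintaining a valid JSON array on disk, the file only becomes parseable once wrapped in brackets, as the in-code comment says; filelock.FileLock serializes concurrent appends. A hedged sketch of a reader for that layout (file name per the new code; the row separator is an assumption, since the diff is truncated before it):

import json
import os

def load_history(save_dir):
    # Rows are assumed to be comma-separated JSON objects without
    # surrounding brackets; adding '[' and ']' rebuilds a JSON array.
    path = os.path.join(save_dir, 'history.json')
    with open(path) as f:
        body = f.read().strip().rstrip(',')
    return json.loads('[' + body + ']')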