pseudotensor committed on
Commit
b43c18e
1 Parent(s): 5b1d132

Update with h2oGPT hash 05d3ad444971c24fb021ea80c27f867c7a953699

Files changed (7)
  1. client_test.py +4 -2
  2. finetune.py +60 -10
  3. generate.py +98 -83
  4. gradio_runner.py +26 -10
  5. prompter.py +6 -5
  6. requirements.txt +1 -1
  7. stopping.py +49 -6
client_test.py CHANGED
@@ -53,13 +53,16 @@ def get_client():
53
 
54
 
55
  def test_client_basic():
 
 
 
 
56
  instruction = '' # only for chat=True
57
  iinput = '' # only for chat=True
58
  context = ''
59
  # streaming output is supported, loops over and outputs each generation in streaming mode
60
  # but leave stream_output=False for simple input/output mode
61
  stream_output = False
62
- prompt_type = 'human_bot'
63
  temperature = 0.1
64
  top_p = 0.75
65
  top_k = 40
@@ -73,7 +76,6 @@ def test_client_basic():
73
  do_sample = True
74
  # only these 2 below used if pass chat=False
75
  chat = False
76
- instruction_nochat = "Who are you?"
77
  iinput_nochat = ''
78
 
79
  args = [instruction,
 
53
 
54
 
55
  def test_client_basic():
56
+ return run_client_basic(instruction_nochat='Who are you?', prompt_type='human_bot')
57
+
58
+
59
+ def run_client_basic(instruction_nochat, prompt_type):
60
  instruction = '' # only for chat=True
61
  iinput = '' # only for chat=True
62
  context = ''
63
  # streaming output is supported, loops over and outputs each generation in streaming mode
64
  # but leave stream_output=False for simple input/output mode
65
  stream_output = False
 
66
  temperature = 0.1
67
  top_p = 0.75
68
  top_k = 40
 
76
  do_sample = True
77
  # only these 2 below used if pass chat=False
78
  chat = False
 
79
  iinput_nochat = ''
80
 
81
  args = [instruction,
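
The refactor above turns the hard-coded test into a reusable run_client_basic(instruction_nochat, prompt_type) helper. A minimal sketch of how further tests could reuse it; the test names and prompt choices here are hypothetical, not part of this commit:

    from client_test import run_client_basic

    def test_client_plain():
        # let the model's own tokenizer/remote code drive prompting
        return run_client_basic(instruction_nochat='Who are you?', prompt_type='plain')

    def test_client_prompt_answer():
        # exercises the new <|prompt|>/<|answer|> prompt type added in finetune.py
        return run_client_basic(instruction_nochat='Who are you?', prompt_type='prompt_answer')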
finetune.py CHANGED
@@ -28,6 +28,8 @@ class PromptType(Enum):
28
  instruct_vicuna = 7
29
  instruct_with_end = 8
30
  human_bot_orig = 9
 
 
31
 
32
 
33
  prompt_type_to_model_name = {
@@ -46,6 +48,14 @@ prompt_type_to_model_name = {
46
  'philschmid/flan-t5-base-samsum',
47
  'gpt2',
48
  'distilgpt2',
 
 
 
 
 
 
 
 
49
  ],
50
  'instruct': [],
51
  'instruct_with_end': ['databricks/dolly-v2-12b'],
@@ -61,14 +71,12 @@ prompt_type_to_model_name = {
61
  'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
62
  'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b'],
63
  'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
 
64
  }
65
 
66
  inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
67
  inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
68
 
69
- human = '<human>:'
70
- bot = "<bot>:"
71
-
72
  prompt_types_strings = []
73
  for p in PromptType:
74
  prompt_types_strings.extend([p.name])
@@ -277,8 +285,13 @@ def train(
277
  layer_norm_names=["layer_norm", "layernorm"], # keep all layer norms in higher precision
278
  )
279
 
280
- from peft import LoraConfig, get_peft_model, set_peft_model_state_dict, utils
281
- lora_mappings = utils.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
 
 
 
 
 
282
  lora_mappings['distilgpt2'] = ["c_attn"]
283
 
284
  if lora_weights:
@@ -730,10 +743,10 @@ def generate_and_tokenize_prompt(data_point, prompt_type=None, train_on_inputs=F
730
  assert prompt_type is not None
731
  assert cutoff_len is not None
732
  assert tokenizer is not None
733
- full_prompt, _, _ = generate_prompt(data_point, prompt_type, False, False)
734
  tokenized_full_prompt = tokenize(full_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
735
  if not train_on_inputs:
736
- user_prompt, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False)
737
  tokenized_user_prompt = tokenize(user_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
738
  user_prompt_len = len(tokenized_user_prompt["input_ids"])
739
  if add_eos_token:
@@ -752,9 +765,11 @@ def get_prompt(prompt_type, chat, context, reduced):
752
  if prompt_type in [-1, "-1", "plain"]:
753
  promptA = promptB = PreInstruct = PreInput = PreResponse = ''
754
  terminate_response = []
 
755
  elif prompt_type == 'simple_instruct':
756
  promptA = promptB = PreInstruct = PreInput = PreResponse = None
757
  terminate_response = []
 
758
  elif prompt_type in [0, "0", "instruct"] or prompt_type in [7, "7", "instruct_with_end"]:
759
  promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
760
  promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
@@ -774,6 +789,7 @@ def get_prompt(prompt_type, chat, context, reduced):
774
  terminate_response = ['### End']
775
  else:
776
  terminate_response = None
 
777
  elif prompt_type in [1, "1", "quality"]:
778
  promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (chat and reduced) else ''
779
  promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (chat and reduced) else ''
@@ -790,7 +806,10 @@ def get_prompt(prompt_type, chat, context, reduced):
790
  ### Response:
791
  """
792
  terminate_response = None
 
793
  elif prompt_type in [2, "2", "human_bot", 9, "9", "human_bot_orig"]:
 
 
794
  if reduced or context or prompt_type in [2, "2", "human_bot"]:
795
  preprompt = ''
796
  else:
@@ -819,6 +838,7 @@ Current Time: {}
819
  PreResponse = bot
820
 
821
  terminate_response = [start, PreResponse]
 
822
  elif prompt_type in [3, "3", "dai_faq"]:
823
  promptA = ''
824
  promptB = 'Answer the following Driverless AI question.\n'
@@ -833,11 +853,13 @@ Current Time: {}
833
  ### Driverless AI documentation answer:
834
  """
835
  terminate_response = ['\n\n']
 
836
  elif prompt_type in [5, "5", "summarize"]:
837
  promptA = promptB = PreInput = ''
838
  PreInstruct = '## Main Text\n\n'
839
  PreResponse = '\n\n## Summary\n\n'
840
  terminate_response = None
 
841
  elif prompt_type in [6, "6", "instruct_vicuna"]:
842
  promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
843
  "The assistant gives helpful, detailed, and polite answers to the human's questions." if not (chat and reduced) else ''
@@ -852,10 +874,37 @@ Current Time: {}
852
  ### Assistant:
853
  """
854
  terminate_response = ['### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
855
  else:
856
  raise RuntimeError("No such prompt_type=%s" % prompt_type)
857
 
858
- return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response
859
 
860
 
861
  def generate_prompt(data_point, prompt_type, chat, reduced):
@@ -867,7 +916,8 @@ def generate_prompt(data_point, prompt_type, chat, reduced):
867
  output = data_point.get('output')
868
  prompt_type = data_point.get('prompt_type', prompt_type)
869
  assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
870
- promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response = get_prompt(prompt_type, chat, context, reduced)
 
871
 
872
  prompt = context if not reduced else ''
873
 
@@ -919,7 +969,7 @@ def generate_prompt(data_point, prompt_type, chat, reduced):
919
  if output:
920
  prompt += f"""{output}"""
921
 
922
- return prompt, pre_response, terminate_response
923
 
924
 
925
  def inject_newline(prompt_type, prompt):
 
28
  instruct_vicuna = 7
29
  instruct_with_end = 8
30
  human_bot_orig = 9
31
+ prompt_answer = 10
32
+ open_assistant = 11
33
 
34
 
35
  prompt_type_to_model_name = {
 
48
  'philschmid/flan-t5-base-samsum',
49
  'gpt2',
50
  'distilgpt2',
51
+ 'mosaicml/mpt-7b-storywriter',
52
+ 'mosaicml/mpt-7b-instruct', # internal code handles instruct
53
+ 'mosaicml/mpt-7b-chat', # NC, internal code handles instruct
54
+ ],
55
+ 'prompt_answer': [
56
+ 'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
57
+ 'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
58
+ 'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
59
  ],
60
  'instruct': [],
61
  'instruct_with_end': ['databricks/dolly-v2-12b'],
 
71
  'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
72
  'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b'],
73
  'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
74
+ "open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'],
75
  }
76
 
77
  inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
78
  inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
79
 
 
 
 
80
  prompt_types_strings = []
81
  for p in PromptType:
82
  prompt_types_strings.extend([p.name])
 
285
  layer_norm_names=["layer_norm", "layernorm"], # keep all layer norms in higher precision
286
  )
287
 
288
+ from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
289
+ try:
290
+ from peft import utils
291
+ lora_mappings = utils.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
292
+ except AttributeError:
293
+ from peft import mapping
294
+ lora_mappings = mapping.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
295
  lora_mappings['distilgpt2'] = ["c_attn"]
296
 
297
  if lora_weights:
 
743
  assert prompt_type is not None
744
  assert cutoff_len is not None
745
  assert tokenizer is not None
746
+ full_prompt, _, _, _ = generate_prompt(data_point, prompt_type, False, False)
747
  tokenized_full_prompt = tokenize(full_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
748
  if not train_on_inputs:
749
+ user_prompt, _, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False)
750
  tokenized_user_prompt = tokenize(user_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
751
  user_prompt_len = len(tokenized_user_prompt["input_ids"])
752
  if add_eos_token:
 
765
  if prompt_type in [-1, "-1", "plain"]:
766
  promptA = promptB = PreInstruct = PreInput = PreResponse = ''
767
  terminate_response = []
768
+ chat_sep = ''
769
  elif prompt_type == 'simple_instruct':
770
  promptA = promptB = PreInstruct = PreInput = PreResponse = None
771
  terminate_response = []
772
+ chat_sep = '\n'
773
  elif prompt_type in [0, "0", "instruct"] or prompt_type in [7, "7", "instruct_with_end"]:
774
  promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
775
  promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
 
789
  terminate_response = ['### End']
790
  else:
791
  terminate_response = None
792
+ chat_sep = '\n'
793
  elif prompt_type in [1, "1", "quality"]:
794
  promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (chat and reduced) else ''
795
  promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (chat and reduced) else ''
 
806
  ### Response:
807
  """
808
  terminate_response = None
809
+ chat_sep = '\n'
810
  elif prompt_type in [2, "2", "human_bot", 9, "9", "human_bot_orig"]:
811
+ human = '<human>:'
812
+ bot = "<bot>:"
813
  if reduced or context or prompt_type in [2, "2", "human_bot"]:
814
  preprompt = ''
815
  else:
 
838
  PreResponse = bot
839
 
840
  terminate_response = [start, PreResponse]
841
+ chat_sep = '\n'
842
  elif prompt_type in [3, "3", "dai_faq"]:
843
  promptA = ''
844
  promptB = 'Answer the following Driverless AI question.\n'
 
853
  ### Driverless AI documentation answer:
854
  """
855
  terminate_response = ['\n\n']
856
+ chat_sep = terminate_response
857
  elif prompt_type in [5, "5", "summarize"]:
858
  promptA = promptB = PreInput = ''
859
  PreInstruct = '## Main Text\n\n'
860
  PreResponse = '\n\n## Summary\n\n'
861
  terminate_response = None
862
+ chat_sep = '\n'
863
  elif prompt_type in [6, "6", "instruct_vicuna"]:
864
  promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
865
  "The assistant gives helpful, detailed, and polite answers to the human's questions." if not (chat and reduced) else ''
 
874
  ### Assistant:
875
  """
876
  terminate_response = ['### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
877
+ chat_sep = '\n'
878
+ elif prompt_type in [10, "10", "prompt_answer"]:
879
+ preprompt = ''
880
+ prompt_tokens = "<|prompt|>"
881
+ answer_tokens = "<|answer|>"
882
+ start = prompt_tokens
883
+ promptB = promptA = '%s%s' % (preprompt, start)
884
+ PreInstruct = ""
885
+ PreInput = None
886
+ PreResponse = answer_tokens
887
+ eos = '<|endoftext|>' # neox eos
888
+ terminate_response = [start, PreResponse, eos]
889
+ chat_sep = eos
890
+ elif prompt_type in [11, "11", "open_assistant"]:
891
+ # From added_tokens.json
892
+ preprompt = ''
893
+ prompt_tokens = "<|prompter|>"
894
+ answer_tokens = "<|assistant|>"
895
+ start = prompt_tokens
896
+ promptB = promptA = '%s%s' % (preprompt, start)
897
+ PreInstruct = ""
898
+ PreInput = None
899
+ PreResponse = answer_tokens
900
+ pend = "<|prefix_end|>"
901
+ eos = "</s>"
902
+ terminate_response = [start, PreResponse, pend, eos]
903
+ chat_sep = eos
904
  else:
905
  raise RuntimeError("No such prompt_type=%s" % prompt_type)
906
 
907
+ return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response, chat_sep
908
 
909
 
910
  def generate_prompt(data_point, prompt_type, chat, reduced):
 
916
  output = data_point.get('output')
917
  prompt_type = data_point.get('prompt_type', prompt_type)
918
  assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
919
+ promptA, promptB, PreInstruct, PreInput, PreResponse, \
920
+ terminate_response, chat_sep = get_prompt(prompt_type, chat, context, reduced)
921
 
922
  prompt = context if not reduced else ''
923
 
 
969
  if output:
970
  prompt += f"""{output}"""
971
 
972
+ return prompt, pre_response, terminate_response, chat_sep
973
 
974
 
975
  def inject_newline(prompt_type, prompt):
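
For reference, a minimal sketch (assuming this repo is on PYTHONPATH) that prints the pieces get_prompt() now returns for the two new prompt types; the values noted in the comments are the literals added above:

    from finetune import get_prompt

    for ptype in ('prompt_answer', 'open_assistant'):
        (promptA, promptB, PreInstruct, PreInput, PreResponse,
         terminate_response, chat_sep) = get_prompt(ptype, chat=False, context='', reduced=False)
        print(ptype, repr(promptA), repr(PreResponse), repr(chat_sep))
    # prompt_answer:  '<|prompt|>'   '<|answer|>'    '<|endoftext|>'
    # open_assistant: '<|prompter|>' '<|assistant|>' '</s>'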
generate.py CHANGED
@@ -9,7 +9,7 @@ from datetime import datetime
9
  import filelock
10
  import psutil
11
 
12
- from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread
13
 
14
  SEED = 1236
15
  set_seed(SEED)
@@ -22,13 +22,13 @@ import pandas as pd
22
  import fire
23
  import torch
24
  from peft import PeftModel
25
- from transformers import GenerationConfig, StoppingCriteriaList, AutoModel, TextIteratorStreamer
26
  from accelerate import init_empty_weights, infer_auto_device_map
27
 
28
  from prompter import Prompter
29
 
30
- from finetune import get_loaders, example_data_points, generate_prompt, human, bot, inv_prompt_type_to_model_lower
31
- from stopping import StoppingCriteriaSub
32
 
33
  eval_extra_columns = ['prompt', 'response', 'score']
34
 
@@ -62,6 +62,7 @@ def main(
62
  local_files_only: bool = False,
63
  resume_download: bool = True,
64
  use_auth_token: Union[str, bool] = False,
 
65
 
66
  src_lang: str = "English",
67
  tgt_lang: str = "Russian",
@@ -124,6 +125,7 @@ def main(
124
  :param local_files_only: whether to only use local files instead of going to HF for models
125
  :param resume_download: whether to resume downloads from HF for models
126
  :param use_auth_token: whether to use HF auth token (requires running huggingface-cli login on the CLI beforehand)
 
127
  :param src_lang: source languages to include if doing translation (None = all)
128
  :param tgt_lang: target languages to include if doing translation (None = all)
129
  :param gradio: whether to enable gradio, or to enable benchmark mode
@@ -168,15 +170,22 @@ def main(
168
 
169
  if is_public:
170
  input_lines = 1 # ensure set, for ease of use
171
- temperature = 0.2
172
- top_p = 0.85
173
- top_k = 70
174
- do_sample = True
 
 
 
 
 
175
  if is_low_mem:
176
- base_model = 'h2oai/h2ogpt-oasst1-512-12b'
177
- load_8bit = True
 
 
178
  else:
179
- base_model = 'h2oai/h2ogpt-oasst1-512-20b'
180
  if is_low_mem:
181
  load_8bit = True
182
  if is_hf:
@@ -229,6 +238,11 @@ def main(
229
  do_sample,
230
  )
231
 
 
 
 
 
 
232
  if not gradio:
233
  if eval_sharegpt_prompts_only > 0:
234
  # override default examples with shareGPT ones for human-level eval purposes only
@@ -416,7 +430,11 @@ def get_device():
416
 
417
  def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
418
  gpu_id=0,
419
- use_auth_token=False):
 
 
 
 
420
  """
421
  Ensure model gets on correct device
422
  :param base_model:
@@ -426,29 +444,47 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
426
  :param reward_type:
427
  :param gpu_id:
428
  :param use_auth_token:
 
 
 
429
  :return:
430
  """
431
  with init_empty_weights():
432
  from transformers import AutoConfig
433
- config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token)
434
- model = AutoModel.from_config(
435
- config,
436
- )
437
-
438
- # NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model
439
- # NOTE: Some models require avoiding sharding some layers,
440
- # then would pass no_split_module_classes and give list of those layers.
441
- device_map = infer_auto_device_map(
442
- model,
443
- dtype=torch.float16 if load_half else torch.float32,
444
- )
445
- if hasattr(model, 'model'):
446
- device_map_model = infer_auto_device_map(
447
- model.model,
 
 
 
 
 
 
 
 
448
  dtype=torch.float16 if load_half else torch.float32,
449
  )
450
- device_map.update(device_map_model)
451
- print('device_map: %s' % device_map, flush=True)
 
 
 
 
 
 
 
452
 
453
  n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
454
 
@@ -472,11 +508,13 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
472
  if load_in_8bit or not load_half:
473
  model = model_loader.from_pretrained(
474
  base_model,
 
475
  **model_kwargs,
476
  )
477
  else:
478
  model = model_loader.from_pretrained(
479
  base_model,
 
480
  **model_kwargs,
481
  ).half()
482
  return model
@@ -495,6 +533,7 @@ def get_model(
495
  local_files_only: bool = False,
496
  resume_download: bool = True,
497
  use_auth_token: Union[str, bool] = False,
 
498
  compile: bool = True,
499
  **kwargs,
500
  ):
@@ -513,6 +552,7 @@ def get_model(
513
  :param local_files_only: use local files instead of from HF
514
  :param resume_download: resume downloads from HF
515
  :param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
 
516
  :param compile: whether to compile torch model
517
  :param kwargs:
518
  :return:
@@ -531,7 +571,8 @@ def get_model(
531
  )
532
 
533
  from transformers import AutoConfig
534
- config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token)
 
535
  llama_type_from_config = 'llama' in str(config).lower()
536
  llama_type_from_name = "llama" in base_model.lower()
537
  llama_type = llama_type_from_config or llama_type_from_name
@@ -548,6 +589,7 @@ def get_model(
548
  local_files_only=local_files_only,
549
  resume_download=resume_download,
550
  use_auth_token=use_auth_token,
 
551
  )
552
  else:
553
  tokenizer = tokenizer_loader
@@ -563,13 +605,18 @@ def get_model(
563
  model_kwargs = dict(local_files_only=local_files_only,
564
  torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
565
  resume_download=resume_download,
566
- use_auth_token=use_auth_token)
567
- if 'mbart-' not in base_model.lower():
 
 
568
  model_kwargs.update(dict(load_in_8bit=load_8bit,
569
  device_map={"": 0} if load_8bit and device == 'cuda' else "auto",
570
  ))
 
 
 
571
  if 'OpenAssistant/reward-model'.lower() in base_model.lower():
572
- # could put on other GPUs
573
  model_kwargs['device_map'] = {"": 0} if device == 'cuda' else {"": 'cpu'}
574
  model_kwargs.pop('torch_dtype', None)
575
 
@@ -577,7 +624,10 @@ def get_model(
577
  with torch.device(device):
578
  if infer_devices:
579
  model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
580
- gpu_id=gpu_id, use_auth_token=use_auth_token)
 
 
 
581
  else:
582
  if load_half and not load_8bit:
583
  model = model_loader.from_pretrained(
@@ -599,6 +649,7 @@ def get_model(
599
  local_files_only=local_files_only,
600
  resume_download=resume_download,
601
  use_auth_token=use_auth_token,
 
602
  device_map={"": 0} if device == 'cuda' else {"": 'cpu'}, # seems to be required
603
  )
604
  else:
@@ -614,6 +665,7 @@ def get_model(
614
  local_files_only=local_files_only,
615
  resume_download=resume_download,
616
  use_auth_token=use_auth_token,
 
617
  device_map="auto",
618
  )
619
  if load_half:
@@ -782,49 +834,7 @@ def evaluate(
782
  if chat:
783
  # override, ignore user change
784
  num_return_sequences = 1
785
- if prompt_type in ['human_bot', 'instruct_vicuna', 'instruct_with_end']:
786
- if prompt_type == 'human_bot':
787
- # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
788
- # stopping only starts once output is beyond prompt
789
- # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
790
- stop_words = [human, bot, '\n' + human, '\n' + bot]
791
- encounters = [1, 2]
792
- elif prompt_type == 'instruct_vicuna':
793
- # even below is not enough, generic strings and many ways to encode
794
- stop_words = [
795
- '### Human:',
796
- """
797
- ### Human:""",
798
- """
799
- ### Human:
800
- """,
801
- '### Assistant:',
802
- """
803
- ### Assistant:""",
804
- """
805
- ### Assistant:
806
- """,
807
- ]
808
- encounters = [1, 2]
809
- else:
810
- # some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
811
- stop_words = ['### End']
812
- encounters = [1]
813
- stop_words_ids = [
814
- tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
815
- # handle single token case
816
- stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
817
- stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
818
- # avoid padding in front of tokens
819
- if tokenizer.pad_token:
820
- stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
821
- # handle fake \n added
822
- stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
823
- # build stopper
824
- stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device)])
825
- else:
826
- stopping_criteria = StoppingCriteriaList()
827
-
828
  # help to avoid errors like:
829
  # RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
830
  # RuntimeError: expected scalar type Half but found Float
@@ -903,7 +913,10 @@ def evaluate(
903
  prompt = inputs_decoded
904
  elif inputs_decoded_raw == prompt:
905
  # some models specify special tokens that are part of normal prompt, so can't skip them
906
- inputs_decoded_raw = inputs_decoded
 
 
 
907
  decoder = decoder_raw
908
  else:
909
  print("WARNING: Special characters in prompt", flush=True)
@@ -1046,6 +1059,7 @@ def get_generate_params(model_lower, chat,
1046
 
1047
  if not prompt_type and model_lower in inv_prompt_type_to_model_lower:
1048
  prompt_type = inv_prompt_type_to_model_lower[model_lower]
 
1049
 
1050
  # examples at first don't include chat, instruction_nochat, iinput_nochat, added at end
1051
  if show_examples is None:
@@ -1104,7 +1118,8 @@ Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-pa
1104
  placeholder_instruction = "Give detailed answer for whether Einstein or Newton is smarter."
1105
  placeholder_input = ""
1106
  if model_lower:
1107
- prompt_type = prompt_type or 'human_bot'
 
1108
  else:
1109
  prompt_type = ''
1110
  examples += [[summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else '', "",
@@ -1133,9 +1148,9 @@ Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-pa
1133
  num_return_sequences = min(num_beams, num_return_sequences or 1)
1134
  do_sample = False if do_sample is None else do_sample
1135
  else:
1136
- temperature = 0.2 if temperature is None else temperature
1137
- top_p = 0.85 if top_p is None else top_p
1138
- top_k = 70 if top_k is None else top_k
1139
  if chat:
1140
  num_beams = num_beams or 1
1141
  else:
@@ -1143,7 +1158,7 @@ Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-pa
1143
  max_new_tokens = max_new_tokens or 256
1144
  repetition_penalty = repetition_penalty or 1.07
1145
  num_return_sequences = min(num_beams, num_return_sequences or 1)
1146
- do_sample = True if do_sample is None else do_sample
1147
  # doesn't include chat, instruction_nochat, iinput_nochat, added later
1148
  params_list = ["", stream_output, prompt_type, temperature, top_p, top_k, num_beams, max_new_tokens, min_new_tokens,
1149
  early_stopping, max_time, repetition_penalty, num_return_sequences, do_sample]
 
9
  import filelock
10
  import psutil
11
 
12
+ from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash
13
 
14
  SEED = 1236
15
  set_seed(SEED)
 
22
  import fire
23
  import torch
24
  from peft import PeftModel
25
+ from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
26
  from accelerate import init_empty_weights, infer_auto_device_map
27
 
28
  from prompter import Prompter
29
 
30
+ from finetune import get_loaders, example_data_points, generate_prompt, inv_prompt_type_to_model_lower
31
+ from stopping import get_stopping
32
 
33
  eval_extra_columns = ['prompt', 'response', 'score']
34
 
 
62
  local_files_only: bool = False,
63
  resume_download: bool = True,
64
  use_auth_token: Union[str, bool] = False,
65
+ trust_remote_code: Union[str, bool] = True,
66
 
67
  src_lang: str = "English",
68
  tgt_lang: str = "Russian",
 
125
  :param local_files_only: whether to only use local files instead of going to HF for models
126
  :param resume_download: whether to resume downloads from HF for models
127
  :param use_auth_token: whether to use HF auth token (requires running huggingface-cli login on the CLI beforehand)
128
+ :param trust_remote_code: whether to trust any code needed for the HF model
129
  :param src_lang: source languages to include if doing translation (None = all)
130
  :param tgt_lang: target languages to include if doing translation (None = all)
131
  :param gradio: whether to enable gradio, or to enable benchmark mode
 
170
 
171
  if is_public:
172
  input_lines = 1 # ensure set, for ease of use
173
+ temperature = 0.2 if temperature is None else temperature
174
+ top_p = 0.85 if top_p is None else top_p
175
+ top_k = 70 if top_k is None else top_k
176
+ if is_hf:
177
+ do_sample = True if do_sample is None else do_sample
178
+ else:
179
+ # by default don't sample, too chatty
180
+ do_sample = False if do_sample is None else do_sample
181
+
182
  if is_low_mem:
183
+ if not base_model:
184
+ base_model = 'h2oai/h2ogpt-oasst1-512-12b'
185
+ # don't set load_8bit if passed base_model, doesn't always work so can't just override
186
+ load_8bit = True
187
  else:
188
+ base_model = 'h2oai/h2ogpt-oasst1-512-20b' if not base_model else base_model
189
  if is_low_mem:
190
  load_8bit = True
191
  if is_hf:
 
238
  do_sample,
239
  )
240
 
241
+ locals_dict = locals()
242
+ locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
243
+ print(f"Generating model with params:\n{locals_print}", flush=True)
244
+ print("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()), flush=True)
245
+
246
  if not gradio:
247
  if eval_sharegpt_prompts_only > 0:
248
  # override default examples with shareGPT ones for human-level eval purposes only
 
430
 
431
  def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
432
  gpu_id=0,
433
+ use_auth_token=False,
434
+ trust_remote_code=True,
435
+ triton_attn=False,
436
+ long_sequence=True,
437
+ ):
438
  """
439
  Ensure model gets on correct device
440
  :param base_model:
 
444
  :param reward_type:
445
  :param gpu_id:
446
  :param use_auth_token:
447
+ :param trust_remote_code:
448
+ :param triton_attn:
449
+ :param long_sequence:
450
  :return:
451
  """
452
  with init_empty_weights():
453
  from transformers import AutoConfig
454
+ config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
455
+ trust_remote_code=trust_remote_code)
456
+ if triton_attn and 'mpt-' in base_model.lower():
457
+ config.attn_config['attn_impl'] = 'triton'
458
+ if long_sequence:
459
+ if 'mpt-7b-storywriter' in base_model.lower():
460
+ config.update({"max_seq_len": 83968})
461
+ if 'mosaicml/mpt-7b-chat' in base_model.lower():
462
+ config.update({"max_seq_len": 4096})
463
+ if issubclass(config.__class__, tuple(AutoModel._model_mapping.keys())):
464
+ model = AutoModel.from_config(
465
+ config,
466
+ )
467
+ else:
468
+ # can't infer
469
+ model = None
470
+
471
+ if model is not None:
472
+ # NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model
473
+ # NOTE: Some models require avoiding sharding some layers,
474
+ # then would pass no_split_module_classes and give list of those layers.
475
+ device_map = infer_auto_device_map(
476
+ model,
477
  dtype=torch.float16 if load_half else torch.float32,
478
  )
479
+ if hasattr(model, 'model'):
480
+ device_map_model = infer_auto_device_map(
481
+ model.model,
482
+ dtype=torch.float16 if load_half else torch.float32,
483
+ )
484
+ device_map.update(device_map_model)
485
+ print('device_map: %s' % device_map, flush=True)
486
+ else:
487
+ device_map = "auto"
488
 
489
  n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
490
 
 
508
  if load_in_8bit or not load_half:
509
  model = model_loader.from_pretrained(
510
  base_model,
511
+ config=config,
512
  **model_kwargs,
513
  )
514
  else:
515
  model = model_loader.from_pretrained(
516
  base_model,
517
+ config=config,
518
  **model_kwargs,
519
  ).half()
520
  return model
 
533
  local_files_only: bool = False,
534
  resume_download: bool = True,
535
  use_auth_token: Union[str, bool] = False,
536
+ trust_remote_code: bool = True,
537
  compile: bool = True,
538
  **kwargs,
539
  ):
 
552
  :param local_files_only: use local files instead of from HF
553
  :param resume_download: resume downloads from HF
554
  :param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
555
+ :param trust_remote_code: trust code needed by model
556
  :param compile: whether to compile torch model
557
  :param kwargs:
558
  :return:
 
571
  )
572
 
573
  from transformers import AutoConfig
574
+ config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
575
+ trust_remote_code=trust_remote_code)
576
  llama_type_from_config = 'llama' in str(config).lower()
577
  llama_type_from_name = "llama" in base_model.lower()
578
  llama_type = llama_type_from_config or llama_type_from_name
 
589
  local_files_only=local_files_only,
590
  resume_download=resume_download,
591
  use_auth_token=use_auth_token,
592
+ trust_remote_code=trust_remote_code,
593
  )
594
  else:
595
  tokenizer = tokenizer_loader
 
605
  model_kwargs = dict(local_files_only=local_files_only,
606
  torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
607
  resume_download=resume_download,
608
+ use_auth_token=use_auth_token,
609
+ trust_remote_code=trust_remote_code,
610
+ )
611
+ if 'mbart-' not in base_model.lower() and 'mpt-' not in base_model.lower():
612
  model_kwargs.update(dict(load_in_8bit=load_8bit,
613
  device_map={"": 0} if load_8bit and device == 'cuda' else "auto",
614
  ))
615
+ if 'mpt-' in base_model.lower() and gpu_id >= 0:
616
+ model_kwargs.update(dict(device_map={"": gpu_id} if device == 'cuda' else "cpu"))
617
+
618
  if 'OpenAssistant/reward-model'.lower() in base_model.lower():
619
+ # FIXME: could put on other GPUs
620
  model_kwargs['device_map'] = {"": 0} if device == 'cuda' else {"": 'cpu'}
621
  model_kwargs.pop('torch_dtype', None)
622
 
 
624
  with torch.device(device):
625
  if infer_devices:
626
  model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
627
+ gpu_id=gpu_id,
628
+ use_auth_token=use_auth_token,
629
+ trust_remote_code=trust_remote_code,
630
+ )
631
  else:
632
  if load_half and not load_8bit:
633
  model = model_loader.from_pretrained(
 
649
  local_files_only=local_files_only,
650
  resume_download=resume_download,
651
  use_auth_token=use_auth_token,
652
+ trust_remote_code=trust_remote_code,
653
  device_map={"": 0} if device == 'cuda' else {"": 'cpu'}, # seems to be required
654
  )
655
  else:
 
665
  local_files_only=local_files_only,
666
  resume_download=resume_download,
667
  use_auth_token=use_auth_token,
668
+ trust_remote_code=trust_remote_code,
669
  device_map="auto",
670
  )
671
  if load_half:
 
834
  if chat:
835
  # override, ignore user change
836
  num_return_sequences = 1
837
+ stopping_criteria = get_stopping(prompt_type, tokenizer, device)
838
  # help to avoid errors like:
839
  # RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
840
  # RuntimeError: expected scalar type Half but found Float
 
913
  prompt = inputs_decoded
914
  elif inputs_decoded_raw == prompt:
915
  # some models specify special tokens that are part of normal prompt, so can't skip them
916
+ inputs_decoded = prompt = inputs_decoded_raw
917
+ decoder = decoder_raw
918
+ elif inputs_decoded_raw.replace("<unk> ", "").replace("<unk>", "").replace('\n', ' ').replace(' ', '') == prompt.replace('\n', ' ').replace(' ', ''):
919
+ inputs_decoded = prompt = inputs_decoded_raw
920
  decoder = decoder_raw
921
  else:
922
  print("WARNING: Special characters in prompt", flush=True)
 
1059
 
1060
  if not prompt_type and model_lower in inv_prompt_type_to_model_lower:
1061
  prompt_type = inv_prompt_type_to_model_lower[model_lower]
1062
+ print("Auto-selecting prompt_type=%s for %s" % (prompt_type, model_lower), flush=True)
1063
 
1064
  # examples at first don't include chat, instruction_nochat, iinput_nochat, added at end
1065
  if show_examples is None:
 
1118
  placeholder_instruction = "Give detailed answer for whether Einstein or Newton is smarter."
1119
  placeholder_input = ""
1120
  if model_lower:
1121
+ # default is plain, because the model might rely upon trust_remote_code to handle prompting
1122
+ prompt_type = prompt_type or 'plain'
1123
  else:
1124
  prompt_type = ''
1125
  examples += [[summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else '', "",
 
1148
  num_return_sequences = min(num_beams, num_return_sequences or 1)
1149
  do_sample = False if do_sample is None else do_sample
1150
  else:
1151
+ temperature = 0.1 if temperature is None else temperature
1152
+ top_p = 0.75 if top_p is None else top_p
1153
+ top_k = 40 if top_k is None else top_k
1154
  if chat:
1155
  num_beams = num_beams or 1
1156
  else:
 
1158
  max_new_tokens = max_new_tokens or 256
1159
  repetition_penalty = repetition_penalty or 1.07
1160
  num_return_sequences = min(num_beams, num_return_sequences or 1)
1161
+ do_sample = False if do_sample is None else do_sample
1162
  # doesn't include chat, instruction_nochat, iinput_nochat, added later
1163
  params_list = ["", stream_output, prompt_type, temperature, top_p, top_k, num_beams, max_new_tokens, min_new_tokens,
1164
  early_stopping, max_time, repetition_penalty, num_return_sequences, do_sample]
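
The get_non_lora_model() changes above skip infer_auto_device_map() for architectures AutoModel cannot build from config (e.g. trust_remote_code models like MPT) and fall back to "auto". A standalone sketch of that fallback pattern, with an illustrative model name and assuming transformers, accelerate, and torch are installed:

    import torch
    from transformers import AutoConfig, AutoModel
    from accelerate import init_empty_weights, infer_auto_device_map

    base_model = 'gpt2'  # illustrative; the commit targets models such as mosaicml/mpt-7b-*
    load_half = True
    with init_empty_weights():
        config = AutoConfig.from_pretrained(base_model, trust_remote_code=True)
        if issubclass(config.__class__, tuple(AutoModel._model_mapping.keys())):
            # architecture known to AutoModel: build a meta model and infer a per-layer map
            model = AutoModel.from_config(config)
            device_map = infer_auto_device_map(model, dtype=torch.float16 if load_half else torch.float32)
        else:
            # remote-code architecture: cannot infer without loading, fall back to "auto"
            device_map = "auto"
    print(device_map)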
gradio_runner.py CHANGED
@@ -5,6 +5,7 @@ import os
5
  import sys
6
 
7
  from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js
 
8
  from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
9
  ping
10
  from finetune import prompt_type_to_model_name, prompt_types_strings, generate_prompt, inv_prompt_type_to_model_lower
@@ -49,6 +50,7 @@ def go_gradio(**kwargs):
49
  """
50
  else:
51
  description = "For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio)<br>"
 
52
  description += """<p>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md)</p>"""
53
 
54
  if kwargs['verbose']:
@@ -389,6 +391,7 @@ def go_gradio(**kwargs):
389
  .then(close_admin, inputs=admin_pass_textbox, outputs=admin_row, queue=False)
390
 
391
  # Get inputs to evaluate()
 
392
  all_kwargs = kwargs.copy()
393
  all_kwargs.update(locals())
394
  inputs_list = get_inputs_list(all_kwargs, kwargs['model_lower'])
@@ -516,9 +519,12 @@ def go_gradio(**kwargs):
516
  :return:
517
  """
518
  args_list = list(args)
519
- user_message = args_list[0]
520
- input1 = args_list[1]
521
- context1 = args_list[2]
 
 
 
522
  if input1 and not user_message.endswith(':'):
523
  user_message1 = user_message + ":" + input1
524
  elif input1:
@@ -528,6 +534,8 @@ def go_gradio(**kwargs):
528
  if sanitize_user_prompt:
529
  from better_profanity import profanity
530
  user_message1 = profanity.censor(user_message1)
 
 
531
  if user_message1 in ['']:
532
  # e.g. when user just hits enter in textbox,
533
  # else will have <human>: <bot>: on single line, which seems to be "ok" for LLM but not usual
@@ -559,7 +567,8 @@ def go_gradio(**kwargs):
559
  :param retry:
560
  :return:
561
  """
562
- args_list = copy.deepcopy(list(args))
 
563
  history = args_list[-1] # model_state is -2
564
  if retry and history:
565
  history.pop()
@@ -580,12 +589,18 @@ def go_gradio(**kwargs):
580
  context1 = ''
581
  for histi in range(len(history) - 1):
582
  data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
583
- context1 += generate_prompt(data_point, prompt_type1, chat1, reduced=True)[0].replace(
584
- '<br>', '\n')
585
- if not context1.endswith('\n'):
586
- context1 += '\n'
587
- if context1 and not context1.endswith('\n'):
588
- context1 += '\n' # ensure if terminates abruptly, then human continues on next line
 
 
 
 
 
 
589
  args_list[0] = instruction1 # override original instruction with history from user
590
  # only include desired chat history
591
  args_list[2] = context1[-kwargs['chat_history']:]
@@ -767,6 +782,7 @@ def go_gradio(**kwargs):
767
  lora_weights = no_lora_str
768
  return [None, None, None, model_name], model_name, lora_weights, prompt_type_old
769
 
 
770
  all_kwargs1 = all_kwargs.copy()
771
  all_kwargs1['base_model'] = model_name.strip()
772
  all_kwargs1['load_8bit'] = load_8bit
 
5
  import sys
6
 
7
  from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js
8
+ from prompter import Prompter
9
  from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
10
  ping
11
  from finetune import prompt_type_to_model_name, prompt_types_strings, generate_prompt, inv_prompt_type_to_model_lower
 
50
  """
51
  else:
52
  description = "For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio)<br>"
53
+ description += "If this host is busy, try [gpt.h2o.ai 20B](https://gpt.h2o.ai) and [30B](http://gpu.hopto.org) and [HF Spaces1 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) and [HF Spaces2 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
54
  description += """<p>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md)</p>"""
55
 
56
  if kwargs['verbose']:
 
391
  .then(close_admin, inputs=admin_pass_textbox, outputs=admin_row, queue=False)
392
 
393
  # Get inputs to evaluate()
394
+ # don't deepcopy, can contain model itself
395
  all_kwargs = kwargs.copy()
396
  all_kwargs.update(locals())
397
  inputs_list = get_inputs_list(all_kwargs, kwargs['model_lower'])
 
519
  :return:
520
  """
521
  args_list = list(args)
522
+ user_message = args_list[eval_func_param_names.index('instruction')] # chat only
523
+ input1 = args_list[eval_func_param_names.index('iinput')] # chat only
524
+ context1 = args_list[eval_func_param_names.index('context')]
525
+ prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
526
+ chat1 = args_list[eval_func_param_names.index('chat')]
527
+ stream_output1 = args_list[eval_func_param_names.index('stream_output')]
528
  if input1 and not user_message.endswith(':'):
529
  user_message1 = user_message + ":" + input1
530
  elif input1:
 
534
  if sanitize_user_prompt:
535
  from better_profanity import profanity
536
  user_message1 = profanity.censor(user_message1)
537
+ # FIXME: WIP to use desired separator when user enters nothing
538
+ prompter = Prompter(prompt_type1, debug=kwargs['debug'], chat=chat1, stream_output=stream_output1)
539
  if user_message1 in ['']:
540
  # e.g. when user just hits enter in textbox,
541
  # else will have <human>: <bot>: on single line, which seems to be "ok" for LLM but not usual
 
567
  :param retry:
568
  :return:
569
  """
570
+ # don't deepcopy, can contain model itself
571
+ args_list = list(args).copy()
572
  history = args_list[-1] # model_state is -2
573
  if retry and history:
574
  history.pop()
 
589
  context1 = ''
590
  for histi in range(len(history) - 1):
591
  data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
592
+ prompt, pre_response, terminate_response, chat_sep = generate_prompt(data_point, prompt_type1,
593
+ chat1, reduced=True)
594
+ # md -> back to text, maybe not super important if model trained enough
595
+ prompt = prompt.replace('<br>', chat_sep)
596
+ context1 += prompt
597
+ if not context1.endswith(chat_sep):
598
+ context1 += chat_sep
599
+
600
+ _, pre_response, terminate_response, chat_sep = generate_prompt({}, prompt_type1, chat1,
601
+ reduced=True)
602
+ if context1 and not context1.endswith(chat_sep):
603
+ context1 += chat_sep # ensure if terminates abruptly, then human continues on next line
604
  args_list[0] = instruction1 # override original instruction with history from user
605
  # only include desired chat history
606
  args_list[2] = context1[-kwargs['chat_history']:]
 
782
  lora_weights = no_lora_str
783
  return [None, None, None, model_name], model_name, lora_weights, prompt_type_old
784
 
785
+ # don't deepcopy, can contain model itself
786
  all_kwargs1 = all_kwargs.copy()
787
  all_kwargs1['base_model'] = model_name.strip()
788
  all_kwargs1['load_8bit'] = load_8bit
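
The chat handler above now looks arguments up by name via eval_func_param_names.index() instead of fixed positions, so reordering evaluate()'s inputs can no longer silently shift them. A toy sketch of the pattern; the short list below is illustrative only, the real eval_func_param_names lives in generate.py:

    eval_func_param_names = ['instruction', 'iinput', 'context', 'stream_output', 'prompt_type', 'chat']
    args_list = ['Who are you?', '', '', False, 'human_bot', True]

    user_message = args_list[eval_func_param_names.index('instruction')]
    prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
    chat1 = args_list[eval_func_param_names.index('chat')]
    print(user_message, prompt_type1, chat1)  # -> Who are you? human_bot True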
prompter.py CHANGED
@@ -6,7 +6,8 @@ class Prompter(object):
6
  allowed_repeat_line_length=10):
7
  self.prompt_type = prompt_type
8
  data_point = dict(instruction='', input='', output='')
9
- _, self.pre_response, self.terminate_response = generate_prompt(data_point, prompt_type, chat, False)
 
10
  self.debug = debug
11
  self.chat = chat
12
  self.stream_output = stream_output
@@ -15,7 +16,7 @@ class Prompter(object):
15
 
16
  def generate_prompt(self, data_point):
17
  reduced = False
18
- prompt, _, _ = generate_prompt(data_point, self.prompt_type, self.chat, reduced)
19
  if self.debug:
20
  print("prompt: ", prompt, flush=True)
21
  self.prompt = prompt
@@ -25,12 +26,12 @@ class Prompter(object):
25
  if isinstance(outputs, str):
26
  outputs = [outputs]
27
  if self.debug:
28
- print("output: ", '\n\n'.join(outputs), flush=True)
29
  if prompt is not None:
30
  self.prompt = prompt
31
 
32
  def clean_response(response):
33
- meaningless_words = ['<pad>', '</s>', '<|endoftext|>', '”\n']
34
  for word in meaningless_words:
35
  response = response.replace(word, "")
36
  if sanitize_bot_response:
@@ -103,5 +104,5 @@ class Prompter(object):
103
  # join all outputs, only one extra new line between outputs
104
  output = '\n'.join(outputs)
105
  if self.debug:
106
- print("outputclean: ", '\n\n'.join(outputs), flush=True)
107
  return output
 
6
  allowed_repeat_line_length=10):
7
  self.prompt_type = prompt_type
8
  data_point = dict(instruction='', input='', output='')
9
+ _, self.pre_response, self.terminate_response, self.chat_sep = \
10
+ generate_prompt(data_point, prompt_type, chat, False)
11
  self.debug = debug
12
  self.chat = chat
13
  self.stream_output = stream_output
 
16
 
17
  def generate_prompt(self, data_point):
18
  reduced = False
19
+ prompt, _, _, _ = generate_prompt(data_point, self.prompt_type, self.chat, reduced)
20
  if self.debug:
21
  print("prompt: ", prompt, flush=True)
22
  self.prompt = prompt
 
26
  if isinstance(outputs, str):
27
  outputs = [outputs]
28
  if self.debug:
29
+ print("output:\n", '\n\n'.join(outputs), flush=True)
30
  if prompt is not None:
31
  self.prompt = prompt
32
 
33
  def clean_response(response):
34
+ meaningless_words = ['<pad>', '</s>', '<|endoftext|>']
35
  for word in meaningless_words:
36
  response = response.replace(word, "")
37
  if sanitize_bot_response:
 
104
  # join all outputs, only one extra new line between outputs
105
  output = '\n'.join(outputs)
106
  if self.debug:
107
+ print("outputclean:\n", '\n\n'.join(outputs), flush=True)
108
  return output
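
A minimal usage sketch (assuming the repo is importable): Prompter now also carries chat_sep from the 4-tuple returned by generate_prompt(), alongside pre_response and terminate_response:

    from prompter import Prompter

    prompter = Prompter('human_bot', debug=False, chat=True, stream_output=False)
    print(repr(prompter.pre_response))   # response marker for this prompt type
    print(prompter.terminate_response)   # stop strings for this prompt type
    print(repr(prompter.chat_sep))       # '\n' for human_bot, per get_prompt() above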
requirements.txt CHANGED
@@ -19,7 +19,7 @@ pandas==2.0.0
19
  matplotlib==3.7.1
20
  loralib==0.1.1
21
  bitsandbytes==0.38.1
22
- git+https://github.com/huggingface/peft.git@e8f66b8a425eced6c592089d40b8d33d82c2b2f0
23
  transformers==4.28.1
24
  tokenizers==0.13.3
25
  APScheduler==3.10.1
 
19
  matplotlib==3.7.1
20
  loralib==0.1.1
21
  bitsandbytes==0.38.1
22
+ git+https://github.com/huggingface/peft.git@098962fa6515f2e4fe83a757f5995d3ffbb1c373
23
  transformers==4.28.1
24
  tokenizers==0.13.3
25
  APScheduler==3.10.1
stopping.py CHANGED
@@ -1,10 +1,5 @@
1
- import traceback
2
- from queue import Queue
3
- from threading import Thread
4
- import collections.abc
5
-
6
  import torch
7
- from transformers import StoppingCriteria
8
 
9
 
10
  class StoppingCriteriaSub(StoppingCriteria):
@@ -21,7 +16,55 @@ class StoppingCriteriaSub(StoppingCriteria):
21
  if torch.all((stop == input_ids[0][-len(stop):])).item():
22
  self.num_stops[stopi] += 1
23
  if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
 
24
  return True
25
  # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
26
  # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
27
  return False
 
1
  import torch
2
+ from transformers import StoppingCriteria, StoppingCriteriaList
3
 
4
 
5
  class StoppingCriteriaSub(StoppingCriteria):
 
16
  if torch.all((stop == input_ids[0][-len(stop):])).item():
17
  self.num_stops[stopi] += 1
18
  if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
19
+ # print("Stopped", flush=True)
20
  return True
21
  # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
22
  # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
23
  return False
24
+
25
+
26
+ def get_stopping(prompt_type, tokenizer, device, human='<human>:', bot="<bot>:"):
27
+ if prompt_type in ['human_bot', 'instruct_vicuna', 'instruct_with_end']:
28
+ if prompt_type == 'human_bot':
29
+ # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
30
+ # stopping only starts once output is beyond prompt
31
+ # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
32
+ stop_words = [human, bot, '\n' + human, '\n' + bot]
33
+ encounters = [1, 2]
34
+ elif prompt_type == 'instruct_vicuna':
35
+ # even below is not enough, generic strings and many ways to encode
36
+ stop_words = [
37
+ '### Human:',
38
+ """
39
+ ### Human:""",
40
+ """
41
+ ### Human:
42
+ """,
43
+ '### Assistant:',
44
+ """
45
+ ### Assistant:""",
46
+ """
47
+ ### Assistant:
48
+ """,
49
+ ]
50
+ encounters = [1, 2]
51
+ else:
52
+ # some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
53
+ stop_words = ['### End']
54
+ encounters = [1]
55
+ stop_words_ids = [
56
+ tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
57
+ # handle single token case
58
+ stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
59
+ stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
60
+ # avoid padding in front of tokens
61
+ if tokenizer.pad_token:
62
+ stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
63
+ # handle fake \n added
64
+ stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
65
+ # build stopper
66
+ stopping_criteria = StoppingCriteriaList(
67
+ [StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device)])
68
+ else:
69
+ stopping_criteria = StoppingCriteriaList()
70
+ return stopping_criteria
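
Usage sketch for the new helper (the tokenizer and device here are illustrative): generate.py now calls get_stopping() this way and passes the result to model.generate():

    from transformers import AutoTokenizer
    from stopping import get_stopping

    tokenizer = AutoTokenizer.from_pretrained('gpt2')  # illustrative tokenizer
    stopping_criteria = get_stopping('human_bot', tokenizer, device='cpu')
    # then: model.generate(**inputs, stopping_criteria=stopping_criteria, ...)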