""" https://github.com/allenai/open-instruct """ import torch import tqdm from transformers import StoppingCriteria, StoppingCriteriaList class KeywordsStoppingCriteria(StoppingCriteria): def __init__(self, keywords_str, tokenizer): StoppingCriteria.__init__(self) self.current_context = [] self.tokenizer = tokenizer self.keywords_str = keywords_str def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: if len(self.current_context) == 0: self.current_context = [[] for _ in range(input_ids.shape[0])] # self.current_context.append(input_ids[0][-1].item()) sequences_should_be_stopped = [] for i in range(input_ids.shape[0]): _id = input_ids[i][-1].item() self.current_context[i].append(_id) current_context = self.tokenizer.decode(self.current_context[i]) should_be_stopped = False for word in self.keywords_str: if word in current_context: should_be_stopped = True break sequences_should_be_stopped.append(should_be_stopped) return all(sequences_should_be_stopped) class KeyWordsCriteriaTrunc(StoppingCriteria): def __init__(self, stop_id_sequences, prompt_length): assert isinstance(stop_id_sequences[0], list), "stop_id_sequences should be a list of list of ids" self.stop_sequences = stop_id_sequences self.prompt_length = prompt_length def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: sequences_should_be_stopped = [] for i in range(input_ids.shape[0]): ids = input_ids[i][self.prompt_length:].tolist() should_be_stopped = False for stop_sequence in self.stop_sequences: if input_ids.shape[0] == 1: _ids = ids[-len(stop_sequence):] else: _ids = ids for j in range(len(_ids), 0, -len(stop_sequence)): if _ids[max(j - len(stop_sequence), 0): j] == stop_sequence: should_be_stopped = True break if should_be_stopped: break sequences_should_be_stopped.append(should_be_stopped) return all(sequences_should_be_stopped) class KeyWordsCriteria(StoppingCriteria): def __init__(self, stop_id_sequences): assert isinstance(stop_id_sequences[0], list), "stop_id_sequences should be a list of list of ids" self.stop_sequences = stop_id_sequences def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: sequences_should_be_stopped = [] for i in range(input_ids.shape[0]): sequence_should_be_stopped = False for stop_sequence in self.stop_sequences: if input_ids[i][-len(stop_sequence):].tolist() == stop_sequence: sequence_should_be_stopped = True break sequences_should_be_stopped.append(sequence_should_be_stopped) return all(sequences_should_be_stopped) @torch.no_grad() def generate_completions(model, tokenizer, prompts, batch_size=1, stop_id_sequences=None, add_special_tokens=True, disable_tqdm=False, **generation_kwargs): generations = [] if not disable_tqdm: progress = tqdm.tqdm(total=len(prompts), desc="Generating Completions") num_return_sequences = generation_kwargs.get("num_return_sequences", 1) for i in range(0, len(prompts), batch_size): batch_prompts = prompts[i:i+batch_size] tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=add_special_tokens) batch_input_ids = tokenized_prompts.input_ids attention_mask = tokenized_prompts.attention_mask if model.device.type == "cuda": batch_input_ids = batch_input_ids.cuda() attention_mask = attention_mask.cuda() # try: stop_criteria = KeywordsStoppingCriteria(stop_id_sequences, tokenizer) batch_outputs = model.generate( input_ids=batch_input_ids, attention_mask=attention_mask, stopping_criteria=StoppingCriteriaList([stop_criteria]), # stopping_criteria=[KeyWordsCriteria(stop_id_sequences)] if stop_id_sequences else None, # stopping_criteria=[KeyWordsCriteriaTrunc(stop_id_sequences, batch_input_ids.size(1))] if stop_id_sequences else None, **generation_kwargs ) # the stopping criteria is applied at batch level, so if other examples are not stopped, the entire batch will continue to generate. # so some outputs still have the stop sequence, which we need to remove. # if stop_id_sequences: # for output_idx in range(batch_outputs.shape[0]): # for token_idx in range(batch_input_ids.shape[1], batch_outputs.shape[1]): # if any(batch_outputs[output_idx, token_idx: token_idx+len(stop_sequence)].tolist() == stop_sequence for stop_sequence in stop_id_sequences): # batch_outputs[output_idx, token_idx:] = tokenizer.pad_token_id # break # remove the prompt from the output # we need to re-encode the prompt because we need to make sure the special tokens are treated the same way as in the outputs. # we changed our previous way of truncating the output token ids dicrectly because some tokenizer (e.g., llama) won't add space token before the first token. # space is important for some tasks (e.g., code completion). batch_outputs = tokenizer.batch_decode(batch_outputs, skip_special_tokens=True) batch_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True) # duplicate the prompts to match the number of return sequences batch_prompts = [prompt for prompt in batch_prompts for _ in range(num_return_sequences)] batch_generations = [ output[len(prompt):] for prompt, output in zip(batch_prompts, batch_outputs) ] # remove the remain stop sequence from the output. for idx, prediction in enumerate(batch_generations): for stop_sequence in stop_id_sequences: batch_generations[idx] = prediction.split(stop_sequence)[0] generations += batch_generations if not disable_tqdm: progress.update(len(batch_prompts)//num_return_sequences) assert len(generations) == len(prompts) * num_return_sequences, "number of generations should be equal to number of prompts * num_return_sequences" return generations def load_hf_lm_and_tokenizer( model_name_or_path, tokenizer_name_or_path=None, device_map="auto", load_in_8bit=False, load_in_half=True, gptq_model=False, use_fast_tokenizer=False, padding_side="left", use_safetensors=False, ): import torch from transformers import AutoModelForCausalLM, AutoTokenizer if not tokenizer_name_or_path: tokenizer_name_or_path = model_name_or_path tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_fast=use_fast_tokenizer, padding_side=padding_side, trust_remote_code=True) # tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, legacy=False, use_fast=use_fast_tokenizer, padding_side=padding_side, trust_remote_code=True) # set pad token to eos token if pad token is not set if tokenizer.pad_token is None: if tokenizer.unk_token: tokenizer.pad_token = tokenizer.unk_token tokenizer.pad_token_id = tokenizer.unk_token_id elif tokenizer.eos_token: tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token_id = tokenizer.eos_token_id else: raise ValueError("You are using a new tokenizer without a pad token." "This is not supported by this script.") # if tokenizer.pad_token is None: # tokenizer.pad_token = tokenizer.unk_token # tokenizer.pad_token_id = tokenizer.unk_token_id if gptq_model: from auto_gptq import AutoGPTQForCausalLM model_wrapper = AutoGPTQForCausalLM.from_quantized( model_name_or_path, device="cuda:0", use_triton=True ) model = model_wrapper.model elif load_in_8bit: model = AutoModelForCausalLM.from_pretrained( model_name_or_path, device_map=device_map, load_in_8bit=True ) else: # return "", tokenizer # defaul load in float16 model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, device_map=device_map, trust_remote_code=True, use_safetensors=use_safetensors) if torch.cuda.is_available(): model = model.cuda() if load_in_half: model = model.half() model.eval() return model, tokenizer def _test_generate_completions(): model_name_or_path = "../models/codellama_7b/v1-16k" llm, tokenizer = load_hf_lm_and_tokenizer( model_name_or_path=model_name_or_path, load_in_half=True, use_fast_tokenizer=True, use_safetensors=True, ) # some math word problems prompts = [ "---\n1+1=2\n---2+2=4\n---3+3=6\n---4+4=8\n---5+5=10\n---6+6=", "---\n1+1=2\n---12+12=24\n---3+3=6\n---12345+12345=", # "A train leaves Chicago at 7am and travels at 60mph. Another train leaves Chicago at 9am and travels at 80mph. When will the second train overtake the first?", # "The sum of two numbers is 10. The difference of the same two numbers is 4. What are the two numbers?", ] stop_sequences = ["\n\n\n", "---"] # Because many tokenizers will treat the word after space differently from the original word alone, # to be consistent, we add a space before tokenization and remove it after tokenization. # stop_id_sequences = [tokenizer.encode(" " + x, add_special_tokens=False)[1:] for x in stop_sequences] outputs = generate_completions( model=llm, tokenizer=tokenizer, prompts=prompts, max_new_tokens=128, batch_size=16, # stop_id_sequences=stop_id_sequences, stop_id_sequences=stop_sequences, ) print(outputs) if __name__ == "__main__": _test_generate_completions()