File size: 5,647 Bytes
793da2f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
# Copyright (c) 2024, SliceX AI, Inc. All Rights Reserved.
from elm.model import *
from elm.utils import batchify
from transformers import AutoTokenizer
import json
def load_elm_model_and_tokenizer(local_path,
                                 model_config_dict,
                                 device="cuda",
                                 load_partial=True,
                                 get_num_layers_from_ckpt=True):
    """Build an ELM model and its tokenizer from a local checkpoint directory.

    Args:
        local_path: checkpoint directory; used both for the model weights and
            for ``AutoTokenizer.from_pretrained``.
        model_config_dict: keyword arguments used to construct ``ModelArgs``.
        device: forwarded to ``load_elm_model_from_ckpt``.
        load_partial: forwarded to ``load_elm_model_from_ckpt``.
        get_num_layers_from_ckpt: forwarded to ``load_elm_model_from_ckpt``.

    Returns:
        ``(model, tokenizer)`` where the tokenizer pads and truncates on the left.
    """
    elm_model = load_elm_model_from_ckpt(
        local_path,
        device=device,
        model_args=ModelArgs(**model_config_dict),
        load_partial=load_partial,
        get_num_layers_from_ckpt=get_num_layers_from_ckpt,
    )
    elm_tokenizer = AutoTokenizer.from_pretrained(local_path)
    # Left side for both: padding goes in front of the prompt, and when truncating,
    # the earliest tokens are dropped so the end of the prompt is kept.
    elm_tokenizer.padding_side = elm_tokenizer.truncation_side = "left"
    return elm_model, elm_tokenizer
def _truncate_at_eos(token_ids, eos_token_id, start=0):
    """Cut a 1-D tensor of token ids just before its first EOS at or after ``start``.

    Positions before ``start`` are never treated as a stop point, so EOS/pad
    tokens that belong to the prompt do not truncate the row. If no EOS is found,
    or ``eos_token_id`` is None, the row is returned unchanged.
    """
    if eos_token_id is None:
        return token_ids
    hits = torch.nonzero(token_ids[start:] == eos_token_id, as_tuple=False)
    if hits.numel() == 0:
        return token_ids
    return token_ids[:start + int(hits[0, 0])]


def generate_elm_response_given_model(prompts, model, tokenizer,
                                      device="cuda",
                                      max_ctx_word_len=1024,
                                      max_ctx_token_len=0,
                                      max_new_tokens=500,
                                      temperature=0.8,  # set to 0 for greedy decoding
                                      top_k=200,
                                      return_tok_cnt=False,
                                      return_gen_only=False,
                                      early_stop_on_eos=False):
    """Generate responses from ELM model given an input list of prompts ([str]).

    All prompts are tokenized together as one padded batch and passed to a single
    ``model.generate`` call.

    Args:
        prompts: list of prompt strings.
        model: ELM model exposing ``eval()`` and ``generate(input_ids,
            max_new_tokens, temperature=..., top_k=..., return_gen_only=...)``.
        tokenizer: Hugging Face tokenizer used to encode prompts and decode outputs.
        device: device the encoded inputs are moved to.
        max_ctx_word_len: used only when ``max_ctx_token_len`` is not positive.
            Each prompt is then cut to its last this-many space-separated words
            before tokenizing.
        max_ctx_token_len: if > 0, prompts are tokenized with truncation to this
            many tokens, on the tokenizer's configured truncation side.
        max_new_tokens: forwarded to ``model.generate``.
        temperature: forwarded to ``model.generate``; 0 selects greedy decoding.
        top_k: forwarded to ``model.generate``.
        return_tok_cnt: if True, also return ``(input_tok_cnt, out_tok_cnt)``.
        return_gen_only: forwarded to ``model.generate``. When False, each output
            row is assumed to begin with the (padded) prompt, followed by the
            generated tokens.
        early_stop_on_eos: if True, cut each output row just before the first EOS
            token that appears after the prompt.

    Returns:
        A list of decoded strings (special tokens kept). If ``return_tok_cnt`` is
        True, a tuple ``(strings, (input_tok_cnt, out_tok_cnt))``, where
        ``input_tok_cnt`` is the element count of the padded input ids and
        ``out_tok_cnt`` is the element count of the raw ``generate`` output,
        measured before any EOS truncation.
    """
    if max_ctx_token_len > 0:
        inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True,
                           max_length=max_ctx_token_len).to(device)
    else:
        prompts = [" ".join(p.split(" ")[-max_ctx_word_len:]) for p in prompts]
        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
    input_tok_cnt = torch.numel(inputs.input_ids)
    model.eval()
    with torch.no_grad():
        outputs = model.generate(inputs.input_ids, max_new_tokens, temperature=temperature, top_k=top_k,
                                 return_gen_only=return_gen_only)
    out_tok_cnt = torch.numel(outputs) if return_tok_cnt else 0
    if early_stop_on_eos:
        # When rows include the prompt, only the generated tail may trigger the stop.
        # Otherwise an EOS inside the prompt, or a pad token equal to EOS under left
        # padding, would truncate the row inside the prompt.
        prompt_len = 0 if return_gen_only else inputs.input_ids.shape[1]
        outputs = [_truncate_at_eos(row, tokenizer.eos_token_id, start=prompt_len)
                   for row in outputs]
    results = tokenizer.batch_decode(outputs, skip_special_tokens=False)
    if return_tok_cnt:
        return results, (input_tok_cnt, out_tok_cnt)
    return results
def generate_elm_responses(elm_model_path,
                           prompts,
                           device=None,
                           elm_model_config=None,
                           eval_batch_size=1,
                           verbose=True,
                           max_new_tokens=None):
    """Load an ELM checkpoint and greedily generate one response per prompt.

    Args:
        elm_model_path: local checkpoint directory passed to
            ``load_elm_model_and_tokenizer``. Its name also selects the default
            generation budget (see ``max_new_tokens``).
        prompts: list of prompt strings.
        device: target device. If falsy, CUDA is used when available, else CPU.
        elm_model_config: optional dict overriding model hyper-parameters. Only
            the keys listed in the defaults below are read; None means "use all
            defaults".
        eval_batch_size: number of prompts per generation call.
        verbose: if True, print each prompt/response pair as JSON.
        max_new_tokens: generation budget per prompt. If None, it is 12 when the
            path contains "classification" or "detection", otherwise 128.

    Returns:
        A list of response strings, appended batch by batch in the order
        ``batchify`` yields them. Each response is the decoded text after the
        last "[/INST]" marker, with surrounding whitespace stripped.
    """
    # None instead of a shared mutable {} default.
    if elm_model_config is None:
        elm_model_config = {}
    if not device:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Setting device to {device}")
    default_model_config = {
        "hidden_size": 2048,
        "max_inp_len": 2048,
        "num_attention_heads": 32,
        "num_layers": 48,
        "bits": 256,
        "vocab_size": 50304,
        "dropout": 0.1,
        "use_rotary_embeddings": True,
    }
    # Only known keys are forwarded to ModelArgs; unknown keys in elm_model_config are ignored.
    model_config_dict = {key: elm_model_config.get(key, default)
                         for key, default in default_model_config.items()}
    model, tokenizer = load_elm_model_and_tokenizer(local_path=elm_model_path, model_config_dict=model_config_dict,
                                                    device=device, load_partial=True)
    if max_new_tokens is None:
        # Classification/detection checkpoints only need a short label as output.
        is_short_output_task = "classification" in elm_model_path or "detection" in elm_model_path
        max_new_tokens = 12 if is_short_output_task else 128
    result = []
    for prompt_batch in batchify(prompts, eval_batch_size):
        responses, _ = generate_elm_response_given_model(prompt_batch,
                                                         model,
                                                         tokenizer,
                                                         device=device,
                                                         max_ctx_word_len=1024,
                                                         max_ctx_token_len=512,
                                                         max_new_tokens=max_new_tokens,
                                                         return_tok_cnt=True,
                                                         return_gen_only=False,
                                                         temperature=0.0,
                                                         early_stop_on_eos=True)
        for prompt, response in zip(prompt_batch, responses):
            # Decoded output still contains the prompt; keep only the text after the last "[/INST]".
            response = response.split("[/INST]")[-1].strip()
            result.append(response)
            if verbose:
                print(json.dumps({"prompt": prompt, "response": response}, indent=4))
                print("\n***\n")
    return result
|