"""Generate text with Llama 3.2 1B Instruct split into three Core ML models
(embeddings, transformer blocks, and LM head) running on the CPU and Neural Engine."""

import math
import os
import time
from argparse import ArgumentParser

import numpy as np
from dotenv import load_dotenv
from transformers import AutoTokenizer

# Read HF_TOKEN (and any other variables) from a local .env file.
load_dotenv()

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", token=os.environ["HF_TOKEN"]
)

parser = ArgumentParser()
parser.add_argument("--model_path_emb", "--model-path-emb", required=True, help="embeddings model (.mlpackage or .mlmodelc)")
parser.add_argument("--model_path_mf", "--model-path-mf", required=True, help="multifunction transformer model")
parser.add_argument("--model_path_head", "--model-path-head", required=True, help="LM head model")
parser.add_argument("--prompt", "-p", required=True, type=str)
parser.add_argument("--max-tokens", "--max_tokens", type=int, default=100, help="number of tokens to generate")
parser.add_argument("--min_p", "--min-p", type=float, default=0.3, help="min-p sampling threshold")
parser.add_argument("--temp", type=float, default=1.0, help="sampling temperature")
args = parser.parse_args()
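# Example invocation (the script name and model file names here are placeholders for
# wherever your exported models live):
#   python generate.py \
#       --model-path-emb llama_emb.mlpackage \
#       --model-path-mf llama_mf.mlpackage \
#       --model-path-head llama_head.mlpackage \
#       -p "Tell me a short story" --max-tokens 100 --min-p 0.3 --temp 1.0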

# coremltools import is deferred until after argument parsing.
import coremltools as ct

print("Loading models...")

# Run the models on the CPU and the Neural Engine.
cu = ct.ComputeUnit.CPU_AND_NE


def load_model(path, fname=None):
    # Compiled .mlmodelc bundles skip the compilation step that .mlpackage files
    # go through when they are first loaded.
    if "mlmodelc" in path:
        return ct.models.CompiledMLModel(path, compute_units=cu, function_name=fname)
    return ct.models.MLModel(path, compute_units=cu, function_name=fname)


emb_model = load_model(args.model_path_emb)
model_1 = load_model(args.model_path_mf, "length_1")
model_40 = load_model(args.model_path_mf, "length_40")
model_head = load_model(args.model_path_head)
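# The three Core ML models split the network into embeddings, transformer blocks, and
# LM head. The transformer package is multifunction: "length_40" consumes 40-token
# chunks during prompt processing and "length_1" consumes one token at a time during
# generation, and both are expected to share the cache state created further below
# with make_state().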


def save_compiled(model):
    """Copy the model's compiled .mlmodelc next to its .mlpackage (not called in this script)."""
    from shutil import copytree

    compiled_model_path = model.get_compiled_model_path()
    copytree(
        compiled_model_path,
        model.package_path.replace(".mlpackage", ".mlmodelc"),
        dirs_exist_ok=True,
    )
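# Usage sketch for the helper above (paths are whatever you passed on the command
# line): after loading an .mlpackage you could run, e.g.,
#   save_compiled(model_40)
# and pass the resulting .mlmodelc path on later runs so load_model takes the
# CompiledMLModel branch. Optional; the generation flow below does not depend on it.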


def min_p_sample(logits, min_p, temp):
    # Apply temperature, then turn the logits into unnormalized probabilities whose
    # largest entry is exactly 1, so `min_p` acts as a threshold relative to the
    # most likely token (min-p filtering).
    logits = logits * (1 / temp)
    logits = logits - np.max(logits, axis=1, keepdims=True)
    probs = np.exp(logits)
    probs[probs < min_p] = 0

    # Sample from the surviving tokens via the inverse CDF: draw a uniform value
    # below the total mass and pick the first index whose cumulative sum exceeds it.
    cumulative = np.cumsum(probs, axis=1)
    draw = np.random.uniform(high=cumulative[:, -1:])
    sample = np.argmax(cumulative > draw, axis=1).astype(np.int32)
    return sample
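# Rough worked example of the function above with temp=1.0 and min_p=0.3:
#   logits [[2.0, 1.0, 0.0]]
#   exp(logits - max)        -> [1.0, 0.368, 0.135]
#   zero entries below 0.3   -> [1.0, 0.368, 0.0]
#   cumulative sum           -> [1.0, 1.368, 1.368]
# A uniform draw in [0, 1.368) then selects token 0 with probability ~0.73 and
# token 1 with probability ~0.27; token 2 has been filtered out entirely.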


def build_causal_mask(seq_length, start, size, end):
    # Start from an all -inf mask and open up (set to 0) the positions each query
    # row is allowed to attend to.
    mask = np.full((1, 1, size, seq_length), np.array(-np.inf, dtype=np.float16))
    _, _, j, k = np.indices(mask.shape)
    # Query rows j < end follow the usual causal rule (keys k up to j + start);
    # rows j >= end attend only to key 0 so no row is entirely -inf.
    mask[((k <= (j + start)) & (j < end)) | ((j >= end) & (k == 0))] = 0
    return mask
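# Rough illustration: build_causal_mask(4, 0, 4, 4)[0, 0] is the standard causal mask
#   [[  0., -inf, -inf, -inf],
#    [  0.,   0., -inf, -inf],
#    [  0.,   0.,   0., -inf],
#    [  0.,   0.,   0.,   0.]]
# while a smaller `end` would leave only the first column open in the rows below it.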


if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# One full causal mask for the 512-token context; chunks of it are sliced out during
# prefill and decoding below.
mask = build_causal_mask(512, 0, 512, 512)

max_length = 40  # chunk size expected by the "length_40" function

prompt = [{"role": "user", "content": args.prompt}]
length = len(tokenizer.apply_chat_template(prompt, add_generation_prompt=True))
print("Prompt length:", length)
input_ids = tokenizer.apply_chat_template(
    prompt,
    return_tensors="np",
    padding=True,
    return_dict=True,
    add_generation_prompt=True,
    tokenizer_kwargs={
        "pad_to_multiple_of": max_length,
    },
)["input_ids"].astype(np.int32)

print("Prompt:\n", tokenizer.decode(input_ids[0]))
state = model_40.make_state()  # cache state shared by the length_40 and length_1 functions
start = time.time()

# Prefill: run the padded prompt through the model in max_length-sized chunks.
for i in range(math.ceil(length / max_length)):
    input_embs = emb_model.predict(
        {"input_ids": input_ids[:, i * max_length : (i + 1) * max_length]}
    )["input_embeddings_channels_first"].astype(np.float16)
    pred = model_40.predict(
        {
            "input_ids": input_embs,
            "query_pos1": np.array([i * max_length], dtype=np.int32),
            "mask": mask[:, :, i * max_length : (i + 1) * max_length],
            "indices": np.arange(i * max_length, (i + 1) * max_length, dtype=np.int32),
        },
        state,
    )
prompt_time = time.time() - start

# Project only the hidden state of the last real prompt token (the final chunk may
# end with padding) through the LM head.
pred = model_head.predict(
    {
        "hidden_states": pred["final_norm_rmsnorm"][..., [length % max_length - 1]].astype(
            np.float16
        )
    }
)

logits = pred["concat_0"]
input_ids = min_p_sample(logits, args.min_p, args.temp)
print("Generated:")
print(tokenizer.decode(input_ids[0]), end="", flush=True)

start = time.time()
# Decode: generate up to max_tokens tokens, one at a time, with the length_1 function.
for i in range(args.max_tokens):
    input_embs = emb_model.predict({"input_ids": input_ids})[
        "input_embeddings_channels_first"
    ].astype(np.float16)
    pred = model_1.predict(
        {
            "input_ids": input_embs,
            "query_pos1": np.array([i + length], dtype=np.int32),
            "mask": mask[:, :, [i + length]],
            "indices": np.array([i + length], dtype=np.int32),
        },
        state,
    )
    pred = model_head.predict(
        {"hidden_states": pred["final_norm_rmsnorm"].astype(np.float16)}
    )
    input_ids = min_p_sample(pred["concat_0"], args.min_p, args.temp)
    print(tokenizer.decode(input_ids[0]), end="", flush=True)
print("", "=" * 10)
generation_time = time.time() - start

print(
    "Prompt:",
    length / prompt_time,
    "tokens-per-sec",
    f"({math.ceil(length / max_length) * max_length / prompt_time} including the processed padding)",
)
print("Generation:", args.max_tokens / generation_time, "tokens-per-sec")