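"""Run a Llama 3.2 1B Instruct pipeline split into three Core ML models.

The pipeline consists of an embedding model, a multifunction transformer body
with a stateful KV cache (functions "length_1" and "length_40"), and an LM-head
model. The prompt is prefilled in 40-token chunks, then tokens are generated
one at a time with min-p sampling.

Example invocation (the model paths below are placeholders for your converted
.mlpackage or compiled .mlmodelc artifacts):

    python coreml_example.py \
        --model-path-emb Llama-3.2-1B-EMB.mlpackage \
        --model-path-mf Llama-3.2-1B-MF.mlpackage \
        --model-path-head Llama-3.2-1B-HEAD.mlpackage \
        --prompt "Tell me a short story." --max-tokens 100

Requires HF_TOKEN in the environment (or a .env file) to download the
meta-llama/Llama-3.2-1B-Instruct tokenizer.
"""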
import time
import math
import numpy as np
from argparse import ArgumentParser
from transformers import AutoTokenizer
from dotenv import load_dotenv
import os
load_dotenv()
# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", token=os.environ["HF_TOKEN"]
)
parser = ArgumentParser()
parser.add_argument("--model_path_emb", "--model-path-emb", required=True)
parser.add_argument("--model_path_mf", "--model-path-mf", required=True)
# parser.add_argument("--model_path_1", "--model-path-1", required=True)
# parser.add_argument("--model_path_40", "--model-path-40", required=True)
parser.add_argument("--model_path_head", "--model-path-head", required=True)
parser.add_argument("--prompt", "-p", required=True, type=str)
parser.add_argument("--max-tokens", "--max_tokens", type=int, default=100)
parser.add_argument("--min_p", "--min-p", type=float, default=0.3)
parser.add_argument("--temp", type=float, default=1.0)
args = parser.parse_args()
import coremltools as ct
print("Loading models...")
cu = ct.ComputeUnit.CPU_AND_NE
def load_model(path, fname=None):
    # Compiled .mlmodelc directories and .mlpackage bundles use different loader classes.
    if "mlmodelc" in path:
        return ct.models.CompiledMLModel(path, cu, fname)
    else:
        return ct.models.MLModel(path, cu, function_name=fname)
emb_model = load_model(args.model_path_emb)
model_1 = load_model(args.model_path_mf, "length_1")
model_40 = load_model(args.model_path_mf, "length_40")
model_head = load_model(args.model_path_head)
def save_compiled(model):
    # Helper (not called below): copy the compiled .mlmodelc next to the source
    # .mlpackage so later runs can load the pre-compiled model directly.
    from shutil import copytree

    compiled_model_path = model.get_compiled_model_path()
    copytree(
        compiled_model_path,
        model.package_path.replace(".mlpackage", ".mlmodelc"),
        dirs_exist_ok=True,
    )
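# Min-p sampling: after shifting logits so the most likely token sits at 0,
# exp(logit) gives each token's probability relative to that top token.
# Tokens whose relative probability falls below `min_p` are discarded and the
# next token is drawn from the cumulative sum of the survivors.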
def min_p_sample(logits, min_p, temp):
    # Apply temperature, then subtract the max for numerical stability.
    logits = logits * (1 / temp)
    logits = logits - np.max(logits, axis=1, keepdims=True)
    probs = np.exp(logits)
    # Drop tokens whose probability relative to the top token is below min_p.
    probs[probs < min_p] = 0
    # Inverse-CDF sampling over the remaining (unnormalized) probabilities.
    cdf = np.cumsum(probs, axis=1)
    sample = np.random.uniform(high=cdf[:, -1:])
    return np.argmax(cdf > sample, axis=1).astype(np.int32)
def build_causal_mask(seq_length, start, size, end):
    # Additive attention mask: 0 where attention is allowed, -inf where it is blocked.
    mask = np.full((1, 1, size, seq_length), np.array(-np.inf, dtype=np.float16))
    i, h, j, k = np.indices(mask.shape)
    # Causal part: query j may attend to keys k <= j + start. Rows at or past
    # `end` get their first column unmasked so softmax never divides by zero.
    mask[((k <= (j + start)) & (j < end)) | ((j >= end) & (k == 0))] = 0
    return mask
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
mask = build_causal_mask(512, 0, 512, 512)
max_length = 40
# length = len(tokenizer(args.prompt)["input_ids"])
prompt = [{"role": "user", "content": args.prompt}]
length = len(tokenizer.apply_chat_template(prompt, add_generation_prompt=True))
print("Prompt length:", length)
input_ids = tokenizer.apply_chat_template(
    prompt,
    return_tensors="np",
    padding=True,
    return_dict=True,
    add_generation_prompt=True,
    tokenizer_kwargs={
        "pad_to_multiple_of": max_length,
    },
)["input_ids"].astype(np.int32)
print("Prompt:\n", tokenizer.decode(input_ids[0]))
state = model_40.make_state()
start = time.time()
for i in range(math.ceil(length / max_length)):
    input_embs = emb_model.predict(
        {"input_ids": input_ids[:, i * max_length : (i + 1) * max_length]}
    )["input_embeddings_channels_first"].astype(np.float16)
    pred = model_40.predict(
        {
            "input_ids": input_embs,
            "query_pos1": np.array([i * max_length], dtype=np.int32),
            "mask": mask[:, :, i * max_length : (i + 1) * max_length],
            "indices": np.arange(i * max_length, (i + 1) * max_length, dtype=np.int32),
        },
        state,
    )
prompt_time = time.time() - start
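# The LM head only needs the last real (non-padding) prompt token. It sits at
# offset length % max_length - 1 inside the final chunk (index -1 when the
# prompt length is an exact multiple of max_length, which still selects the
# last position).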
last_hidden = pred["final_norm_rmsnorm"][..., [length % max_length - 1]]
pred = model_head.predict({"hidden_states": last_hidden.astype(np.float16)})
logits = pred["concat_0"]
input_ids = min_p_sample(logits, args.min_p, args.temp)
print("Generated:")
print(tokenizer.decode(input_ids[0]), end="", flush=True)
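# Decode: generate up to --max-tokens tokens one at a time with the "length_1"
# function, reusing the KV cache state filled during prefill. Each step embeds
# the previously sampled token, advances the position by one, and samples the
# next token with min-p.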
start = time.time()
for i in range(args.max_tokens):
    input_embs = emb_model.predict({"input_ids": input_ids})[
        "input_embeddings_channels_first"
    ].astype(np.float16)
    pred = model_1.predict(
        {
            "input_ids": input_embs,
            "query_pos1": np.array([i + length], dtype=np.int32),
            "mask": mask[:, :, [i + length]],
            "indices": np.array([i + length], dtype=np.int32),
        },
        state,
    )
    pred = model_head.predict(
        {"hidden_states": pred["final_norm_rmsnorm"].astype(np.float16)}
    )
    input_ids = min_p_sample(pred["concat_0"], args.min_p, args.temp)
    print(tokenizer.decode(input_ids[0]), end="", flush=True)
print("", "=" * 10)
generation_time = time.time() - start
print(
    "Prompt:",
    length / prompt_time,
    "tokens-per-sec",
    f"({math.ceil(length / max_length) * max_length / prompt_time} including the processed padding)",
)
print("Generation:", args.max_tokens / generation_time, "tokens-per-sec")