|
"""Download pretrained LLaMA weights from a GCS bucket and build a generator."""

import json
import os
import time
from pathlib import Path

import torch

# Silence the bitsandbytes welcome banner before the `llama` import loads it.
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
from llama import ModelArgs, Transformer, Tokenizer, LLaMA, default_quantize

from google.cloud import storage

# Name of the GCS bucket that holds the pretrained weights.
bucket_name = os.environ.get("GCS_BUCKET")

# Local directories the weights are downloaded into. The checkpoint prefix in
# the bucket is expected to hold one or more *.pth shards plus params.json,
# and the tokenizer prefix a tokenizer.model file (see load() below).
llama_weight_path = "weights/llama"
tokenizer_weight_path = "weights/tokenizer"
|
def download_pretrained_models(
    ckpt_path: str,
    tokenizer_path: str,
):
    """Download model and tokenizer weights from the GCS bucket."""
    print("creating local directories...")
    os.makedirs(llama_weight_path, exist_ok=True)
    os.makedirs(tokenizer_weight_path, exist_ok=True)

    print("initializing GCS client...")
    # An anonymous client is enough as long as the bucket is publicly readable.
    storage_client = storage.Client.create_anonymous_client()
    bucket = storage_client.bucket(bucket_name)

    print(f"downloading {ckpt_path} model weights...")
    blobs = bucket.list_blobs(prefix=f"{ckpt_path}/")
    for blob in blobs:
        filename = blob.name.split("/")[-1]
        if not filename:
            continue  # skip directory placeholder blobs
        print(f"- {filename}")
        blob.download_to_filename(f"{llama_weight_path}/{filename}")

    print(f"downloading {tokenizer_path} tokenizer weights...")
    blobs = bucket.list_blobs(prefix=f"{tokenizer_path}/")
    for blob in blobs:
        filename = blob.name.split("/")[-1]
        if not filename:
            continue
        print(f"- {filename}")
        blob.download_to_filename(f"{tokenizer_weight_path}/{filename}")
|
def get_pretrained_models(
    ckpt_path: str,
    tokenizer_path: str,
) -> LLaMA:
    """Fetch the weights from GCS, then build a ready-to-use generator."""
    download_pretrained_models(ckpt_path, tokenizer_path)

    generator = load(
        ckpt_dir=llama_weight_path,
        tokenizer_path=tokenizer_weight_path,
        max_seq_len=512,
        max_batch_size=1,
    )
    return generator
|
def load(
    ckpt_dir: str,
    tokenizer_path: str,
    max_seq_len: int,
    max_batch_size: int,
) -> LLaMA:
    """Build the model and merge the sharded checkpoints into it."""
    start_time = time.time()
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    print(checkpoints)

    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())

    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=f"{tokenizer_path}/tokenizer.model")
    model_args.vocab_size = tokenizer.n_words

    # Allocate the transformer with fp16 defaults; the default_quantize flag
    # tells the `llama` fork to build its linear layers in quantized form.
    torch.set_default_tensor_type(torch.HalfTensor)
    print("Allocating transformer on host")
    ctx_tok = default_quantize.set(True)
    model = Transformer(model_args)
    default_quantize.set(ctx_tok)  # put the quantization flag back
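
    # The table below maps each parameter's short name to the dimension along
    # which its checkpoint shards are concatenated when the model-parallel
    # shards are merged: 0 = concatenate along dim 0, -1 = along the last
    # dim, None = replicated in every shard (copied once, from shard 0).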
|
    key_to_dim = {
        "w1": 0,
        "w2": -1,
        "w3": 0,
        "wo": -1,
        "wq": 0,
        "wk": 0,
        "wv": 0,
        "output": 0,
        "tok_embeddings": -1,
        "ffn_norm": None,
        "attention_norm": None,
        "norm": None,
        "rope": None,
    }

    # Restore fp32 defaults now that the model itself has been allocated.
    torch.set_default_tensor_type(torch.FloatTensor)
|
    for i, ckpt in enumerate(checkpoints):
        print(f"Loading checkpoint {i}")
        checkpoint = torch.load(ckpt, map_location="cpu")
        for parameter_name, parameter in model.named_parameters():
            short_name = parameter_name.split(".")[-2]
            if key_to_dim[short_name] is None and i == 0:
                # Replicated tensor: take the copy from the first shard.
                parameter.data = checkpoint[parameter_name]
            elif key_to_dim[short_name] == 0:
                # Sharded along dim 0: write shard i into its row block.
                size = checkpoint[parameter_name].size(0)
                parameter.data[size * i : size * (i + 1), :] = checkpoint[
                    parameter_name
                ]
            elif key_to_dim[short_name] == -1:
                # Sharded along the last dim: write shard i into its column block.
                size = checkpoint[parameter_name].size(-1)
                parameter.data[:, size * i : size * (i + 1)] = checkpoint[
                    parameter_name
                ]
            # Free each tensor as soon as it has been copied in.
            del checkpoint[parameter_name]
        del checkpoint

    model.cuda()
|
    generator = LLaMA(model, tokenizer)
    print(
        f"Loaded in {time.time() - start_time:.2f} seconds "
        f"with {torch.cuda.max_memory_allocated() / 1024 ** 3:.2f} GiB"
    )
    return generator
|
def get_output(
    generator: LLaMA,
    prompt: str,
    max_gen_len: int = 256,
    temperature: float = 0.8,
    top_p: float = 0.95,
):
    """Generate a completion for a single prompt."""
    prompts = [prompt]
    results = generator.generate(
        prompts,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
    )
    return results
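

# Example entry point (a minimal sketch): the "7B" and "tokenizer" prefixes
# below are assumptions about how the bucket is laid out, and GCS_BUCKET must
# be set in the environment before running.
if __name__ == "__main__":
    llama_generator = get_pretrained_models("7B", "tokenizer")
    # generate() returns one decoded completion per prompt.
    print(get_output(llama_generator, "The capital of Germany is")[0])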
|
|