|
|
|
|
|
import math |
|
import sys |
|
import time |
|
from pathlib import Path |
|
from typing import Optional |
|
|
|
import lightning as L |
|
import torch |
|
import tqdm |
|
|
|
from lit_llama import LLaMA, Tokenizer |
|
from lit_llama.utils import EmptyInitOnDevice, llama_model_lookup |
|
|
|
from datasets import load_dataset |
|
|
|
|
|
def load_eval_data(dataset_name: str) -> str:
    """Return the evaluation split of a supported corpus as a single string.

    Args:
        dataset_name: One of ``"wikitext"``, ``"ptb"`` or ``"c4"``.

    Returns:
        The corpus text. WikiText-2 and PTB test documents are joined with
        blank lines (``"\\n\\n"``); the first 1100 documents of one C4
        validation shard are joined with single spaces.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """

    def _wikitext() -> str:
        split = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
        return "\n\n".join(split["text"])

    def _ptb() -> str:
        split = load_dataset("ptb_text_only", "penn_treebank", split="test")
        return "\n\n".join(split["sentence"])

    def _c4() -> str:
        split = load_dataset(
            "allenai/c4",
            "allenai--c4",
            data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
            split="validation",
        )
        # Only the first 1100 documents of a single validation shard are used.
        return " ".join(split[:1100]["text"])

    # Name is validated before any dataset is downloaded.
    for name, loader in (("wikitext", _wikitext), ("ptb", _ptb), ("c4", _c4)):
        if dataset_name == name:
            return loader()
    raise ValueError("invalid dataset name (wikitext, ptb, c4 are allowed)")
|
|
|
|
|
def _sum_nll(model, encoded_text, block_size):
    """Sum next-token negative log-likelihood over non-overlapping windows.

    Args:
        model: Module whose output ``model(inp)[0]`` is used as a
            (sequence, vocab) logits matrix; assumes a batch size of 1.
        encoded_text: Token ids with a leading batch dimension of 1,
            i.e. shape (1, num_tokens).
        block_size: Number of tokens fed to the model per forward pass.

    Returns:
        ``(nll_sum, num_predicted_tokens)``. A window of ``n`` tokens
        contributes ``n - 1`` predictions, because its first token has no
        in-window context.
    """
    nll_sum = 0.0
    num_tokens = 0
    with torch.inference_mode():
        for start in tqdm.tqdm(range(0, encoded_text.shape[1], block_size)):
            inp = encoded_text[:, start : start + block_size]
            logits = model(inp)[0]
            # Position t predicts token t + 1: drop the last logit and the first target.
            nll = torch.nn.functional.cross_entropy(
                logits[:-1], inp[0, 1:].to(dtype=torch.long), reduction="sum"
            )
            nll_sum += nll.item()
            num_tokens += inp.size(1) - 1
    return nll_sum, num_tokens


def main(
    datasets: str = "wikitext,ptb,c4",
    *,
    accelerator: str = "auto",
    checkpoint_path: Optional[Path] = None,
    tokenizer_path: Optional[Path] = None,
    dtype: str = "float32",
    quantize: Optional[str] = None,
) -> None:
    """Evaluates the perplexity of a pre-trained LLaMA model on text datasets.

    Each dataset is tokenized, truncated to at most
    ``256 * model.config.block_size`` tokens, and scored in non-overlapping
    windows of ``model.config.block_size`` tokens. Perplexities are printed
    to stdout; timing and memory statistics are printed to stderr.

    Args:
        datasets: The datasets to use as a comma separated string. Any of
            ``"wikitext"``, ``"ptb"``, ``"c4"``; surrounding whitespace and
            empty entries are ignored.
        accelerator: The hardware to run on. Possible choices are:
            ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
        checkpoint_path: The checkpoint path to load. Defaults to
            ``./checkpoints/lit-llama/7B/lit-llama.pth``.
        tokenizer_path: The tokenizer path to load. Defaults to
            ``./checkpoints/lit-llama/tokenizer.model``.
        dtype: Name of a ``torch`` dtype attribute used when instantiating the
            model, e.g. ``"float32"`` or ``"bfloat16"``.
        quantize: Whether to quantize the model and using which method:
            ``"llm.int8"``: LLM.int8() mode,
            ``"gptq.int4"``: GPTQ 4-bit mode.

    Raises:
        FileNotFoundError: If the checkpoint or tokenizer file does not exist.
        ValueError: If ``dtype`` is not a torch dtype, or a dataset name is
            not supported.
    """
    if not checkpoint_path:
        checkpoint_path = Path("./checkpoints/lit-llama/7B/lit-llama.pth")
    if not tokenizer_path:
        tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
    # Explicit checks rather than ``assert`` so validation survives ``python -O``.
    if not checkpoint_path.is_file():
        raise FileNotFoundError(f"checkpoint not found: {checkpoint_path}")
    if not tokenizer_path.is_file():
        raise FileNotFoundError(f"tokenizer not found: {tokenizer_path}")

    fabric = L.Fabric(accelerator=accelerator, devices=1)

    dt = getattr(torch, dtype, None)
    if not isinstance(dt, torch.dtype):
        raise ValueError(f"{dtype} is not a valid dtype.")
    dtype = dt

    with EmptyInitOnDevice(
        device=fabric.device, dtype=dtype, quantization_mode=quantize
    ):
        print("Loading model ...", file=sys.stderr)
        t0 = time.time()
        checkpoint = torch.load(checkpoint_path)
        name = llama_model_lookup(checkpoint)
        model = LLaMA.from_name(name)
        model.load_state_dict(checkpoint)
        print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)

    model.eval()
    model = fabric.setup_module(model)

    tokenizer = Tokenizer(tokenizer_path)

    # Use the model's own context length for both truncation and windowing.
    block_size = model.config.block_size
    total_toks = 0
    total_time = 0.0  # inference time summed over all datasets
    for dsname in (part.strip() for part in datasets.split(",")):
        if not dsname:
            continue
        test_string = load_eval_data(dsname)
        encoded_text = tokenizer.encode(
            test_string, bos=True, eos=False, device=fabric.device
        )
        # Add a batch dimension and cap the evaluation at 256 windows.
        encoded_text = encoded_text[None, : 256 * block_size]

        start = time.perf_counter()
        nlls, toks = _sum_nll(model, encoded_text, block_size)
        total_time += time.perf_counter() - start

        if toks == 0:
            print(f"Not enough tokens to evaluate perplexity on {dsname}", file=sys.stderr)
            continue
        ppl = math.exp(nlls / toks)
        print(f"Perplexity on {dsname}: {ppl:.2f}")
        total_toks += toks

    tokens_per_sec = total_toks / total_time if total_time > 0 else 0.0
    print(
        f"\n\nTime for inference: {total_time:.02f} sec total, {tokens_per_sec:.02f} tokens/sec",
        file=sys.stderr,
    )
    # CUDA memory statistics are meaningless (always 0) on other devices.
    if fabric.device.type == "cuda":
        print(
            f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
            file=sys.stderr,
        )
|
|
|
|
|
if __name__ == "__main__":
    # Permit faster, lower-precision (e.g. TF32) kernels for float32 matmuls.
    torch.set_float32_matmul_precision("high")

    # Imported lazily so the CLI dependency is only needed when run as a script.
    import jsonargparse

    jsonargparse.CLI(main)
|
|