# TPU support

Lit-LLaMA uses `lightning.Fabric` under the hood, which itself supports TPUs (via PyTorch XLA).
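
For context, a TPU run through Fabric looks roughly like the minimal sketch below; this illustrates the Fabric API, not the actual launcher code in the lit-llama scripts.

```python
# Minimal sketch (not lit-llama's own code): Fabric selects the PyTorch XLA
# backend when the "tpu" accelerator is requested.
from lightning.fabric import Fabric

fabric = Fabric(accelerator="tpu")  # devices="auto" picks up the available TPU cores
fabric.launch()
print(fabric.device)  # an XLA device, e.g. xla:0
```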

The following commands will allow you to set up a Google Cloud instance with a TPU v4 VM:

```shell
gcloud compute tpus tpu-vm create lit-llama --version=tpu-vm-v4-pt-2.0 --accelerator-type=v4-8 --zone=us-central2-b
gcloud compute tpus tpu-vm ssh lit-llama --zone=us-central2-b
```
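
If the `ssh` step cannot find the instance, it can help to confirm that the VM actually came up; listing the TPU VMs in the zone is one way to do that:

```shell
gcloud compute tpus tpu-vm list --zone=us-central2-b
```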

Now that you are on the machine, let's clone the repository and install the dependencies:

```shell
git clone https://github.com/Lightning-AI/lit-llama
cd lit-llama
pip install -r requirements.txt
```
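
As a quick, optional sanity check that the installation succeeded, you can print the installed versions:

```shell
python3 -c "import torch, lightning; print(torch.__version__, lightning.__version__)"
```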

By default, computations will run using the new (and experimental) PjRT runtime. Even so, it's recommended that you set the following environment variables:

```shell
export PJRT_DEVICE=TPU
export ALLOW_MULTIPLE_LIBTPU_LOAD=1
```
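
As an optional check that the runtime can see the accelerator (assuming `torch_xla` ships with the TPU VM image), you can ask PyTorch XLA to enumerate the TPU devices:

```shell
python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices())"
```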

> **Note**: You can find an extensive guide on how to get set up and all the available options here.

Since you created a new machine, you'll probably need to download the weights. You could copy them onto the machine with `gcloud compute tpus tpu-vm scp`, or follow the steps described in our downloading guide.
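
For example, copying a local checkpoint directory onto the VM could look roughly like this (the `checkpoints/` paths are just placeholders; point them at wherever your weights actually live):

```shell
gcloud compute tpus tpu-vm scp --recurse checkpoints/ lit-llama:~/lit-llama/checkpoints --zone=us-central2-b
```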

## Inference

Generation works out-of-the-box with TPUs:

```shell
python3 generate.py --prompt "Hello, my name is" --num_samples 2
```

This command will take a long time, as XLA needs to compile the graph (13 min) before running the model. You'll notice that the second sample takes considerably less time (12 sec), since the compiled graph is reused.

## Finetuning

Coming soon.

> **Warning**: When you are done, remember to delete your instance:

```shell
gcloud compute tpus tpu-vm delete lit-llama --zone=us-central2-b
```