In [1]:
# step 0. Preliminary
import torch
# check that cuda doesn't crash on us
print(torch.cuda.get_device_name())
# check that transformers installed
import transformers

NVIDIA A100-PCIE-40GB


In [2]:
EPOCHS=3

In [3]:
# Step 1. Preparing the training
# First ensure that required files are here
from pathlib import Path
assert Path("TinyStoriesV2-GPT4-train.txt").exists()
assert Path("TinyStoriesV2-GPT4-valid.txt").exists()

In [4]:
# Then prepare directories
Path("chunks.txt/train").mkdir(parents=True, exist_ok=True)
Path("chunks.tensors/train").mkdir(parents=True, exist_ok=True)
Path("chunks.txt/valid").mkdir(parents=True, exist_ok=True)
Path("chunks.tensors/valid").mkdir(parents=True, exist_ok=True)

In [5]:
# Then prepare method to split one text to several
from multiprocessing.pool import Pool
from tqdm.contrib.concurrent import process_map
import os
_chunk_me = None
def extract_chunk(chunk):
    split, i, chunk_from, chunk_to = chunk
    chunk = _chunk_me[chunk_from:chunk_to].strip()    
    name = f"chunks.txt/{split}/chunk-{i+1}.txt"
    with open(name, "w") as f:
        f.write(chunk)
    return name

def split_to_text_chunks(split:str, chunk_size = 16*1024*1024, max_workers=None):
    global _chunk_me #text is too chunky to pass as argument. storing as global so fork() can take care of it
    print(f"reading {split}")
    text = _chunk_me = Path(f"./TinyStoriesV2-GPT4-{split}.txt").read_text()
    offsets = []    
    delimiter = "<|endoftext|>"
    i=0
    while i < len(text):    
        offsets.append(i)
        i += chunk_size
        i = text.find(delimiter, i)
        if i < 0:
            break
        i += len(delimiter)
    offsets.append(len(text))
    chunks = [(split, i, start,end) for (i, (start, end)) in enumerate(zip(offsets[:-1], offsets[1:]))]
    
    print("writing")
    process_map(extract_chunk, chunks, max_workers=max_workers)
    

In [7]:
# Prepare text of train split
if not Path("chunks.txt/train/chunk-133.txt").exists():
    split_to_text_chunks("train")
else:
    print("Assuming split has finished already")

Assuming split has finished already


In [9]:
# Prepare text of valid split
if not Path("chunks.txt/valid/chunk-2.txt").exists():
    split_to_text_chunks("valid")    
else:
    print("Assuming split has finished already")

Assuming split has finished already


In [10]:
# Step 2. Prepare OpenLLAMA tokenizer. 
#Needed to be done once(TODO: add code to load tokenizer?)
from transformers import AutoTokenizer
import os
if not Path('tokenizer.json').exists():    
    try:
        tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_3b")
        tokenizer.save_pretrained(".")
    except Exception as e:
        print(e)
        os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"]="python"    
        tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_3b")
        tokenizer.save_pretrained(".")
        del os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"]
tokenizer = AutoTokenizer.from_pretrained(".")

In [11]:
# Step 3. Preparing to tokenize each text chunk
from tqdm.contrib.concurrent import process_map
def tokenize_file(filename:Path):
    text = Path.read_text(filename)
    stories = text.split("<|endoftext|>")
    result = []
    while stories:
        story = stories.pop(0).strip()
        tokenized = tokenizer(story, max_length=None).input_ids
        tokenized.append(tokenizer.eos_token_id)
        result.append(torch.tensor(tokenized))
    output_name = str(filename).replace(".txt", ".tensors")
    torch.save(result, output_name)

def tokenize_split(split, max_workers=None):
    to_process = list(Path(f"chunks.txt/{split}").glob("*"))    
    process_map(tokenize_file, to_process, max_workers=max_workers)
        

In [12]:
# processing train(this can take several minutes)
if not Path("chunks.tensors/train/chunk-133.tensors").exists():
    tokenize_split("train")
else:
    print("Assuming train was tokenized already")

Assuming train was tokenized already


In [13]:
# processing valid(this can take one minutes)
if not Path("chunks.tensors/valid/chunk-2.tensors").exists():
    tokenize_split("valid")
else:
    print("Assuming valid was tokenized already")

Assuming valid was tokenized already


In [14]:
# Step 4. Training. 
# Step 4.1 Preparing tokenizer and setting pad token if it is not set(it is not set)
tokenizer = AutoTokenizer.from_pretrained(".")
if not tokenizer.pad_token_id:
    tokenizer.pad_token_id = tokenizer.eos_token_id
    print("Resetting [PAD] to [EOS]")

Resetting [PAD] to [EOS]


In [18]:
# Step 4.2. Preparing model
from transformers.models.llama.modeling_llama import LlamaConfig, LlamaForCausalLM

tiny_llama = LlamaConfig(
    hidden_size=64, 
    vocab_size=tokenizer.vocab_size,
    intermediate_size=256, 
    num_attention_heads=16, 
    num_hidden_layers=8)

torch.manual_seed(11010)
torch.cuda.manual_seed(11010)
model = LlamaForCausalLM(tiny_llama).cuda().bfloat16()

In [16]:
import functools
import torch.nn.functional as F
from tqdm.contrib.concurrent import process_map
from tqdm.auto import tqdm

# Step 4.3 Preparing dataset class
def get_file_data_len(filename):
    data = torch.load(filename)
    return (filename, len(data))
from datasets import Dataset

CACHE_SIZE = 2000 # There are ~150 train splits. We can fit them in memory, so let's do it

class TinyDataset(Dataset):
    def __init__(self, split: str, populate_cache=True):
        print(f"Reading dataset {split} data")
        self.file_lens = process_map(
            get_file_data_len,
            list(Path(f"chunks.tensors/{split}").glob("*")))
        self.file_lens.sort()
        if populate_cache:
            print("Populating a cache")
            for filename, _ in tqdm(self.file_lens):
                self.load_tensor_file(filename)

    @functools.lru_cache(maxsize=CACHE_SIZE)
    def load_tensor_file(self, filename):
        return torch.load(filename)

    def __len__(self):
        return sum(x[1] for x in self.file_lens)

    def global_index_to_local(self, i):
        for (file, length) in self.file_lens:
            if i < length:
                return (file, i)
            i -= length
        raise IndexError(f"{i} is out-of-bonds, have {len(self)} sample")

    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()
        if isinstance(index, int):
            filename, local_index = self.global_index_to_local(index)
            tensors = self.load_tensor_file(filename)
            return {
                'input_ids': tensors[local_index]
            }
        if isinstance(index, list):
            data = []
            indices = index
            for index in indices:
                filename, local_index = self.global_index_to_local(index)
                tensors = self.load_tensor_file(filename)
                data.append(tensors[local_index])

            return {'input_ids': data}

        raise TypeError(f'Invaldi index type {type(index)}')
        
def batch_collate(data: list[torch.Tensor]):
    max_len = max(len(datum["input_ids"]) for datum in data)
    inputs = []
    attentions = []
    for row in data:
        input_ids = row["input_ids"]
        attention_mask = torch.ones_like(input_ids)
        attention_mask[-1] = 0  # don't care about EOS
        # Manual padding
        to_pad = max_len - len(input_ids)
        is_left_pad = tokenizer.padding_side == "left"
        padding = (is_left_pad * to_pad, (1 - is_left_pad) * to_pad)
        input_ids = F.pad(input_ids, padding, value=tokenizer.pad_token_id)
        attention_mask = F.pad(attention_mask, padding, value=0)
        inputs.append(input_ids)
        attentions.append(attention_mask)

    attention_masks = torch.stack(attentions)
    input_ids = torch.stack(inputs)
    labels = input_ids.clone()

    # disable prediction of the padding
    labels[attention_masks == 0] = -100
    # enable prediction of an actual EOS
    labels[:, -1] = tokenizer.eos_token_id

    return {
        'input_ids': input_ids,
        'attention_mask': attention_masks,
        'labels': labels
    }

def get_max_story_length(ds):    
    return max(file_len[1] for file_len in ds.file_lens)


In [17]:
assert tokenizer.padding_side in ["left", "right"]
train_ds = TinyDataset("train")
assert get_max_story_length(train_ds) <= tokenizer.model_max_length, "WARNIING: split long stories"

Reading dataset train data


  0%|          | 0/133 [00:00<?, ?it/s]

Populating a cache


  0%|          | 0/133 [00:00<?, ?it/s]

AssertionError: WARNIING: split long stories

In [19]:
from torch.utils.data import DataLoader
torch.manual_seed(11010)
torch.cuda.manual_seed(11010)
train_dl = DataLoader(train_ds, 16, True, collate_fn=batch_collate)

In [20]:
# prepare wandb
import wandb
wandb.init(
    project="training-tiny-llama-preview",
    config={
    "architecture": "llama",
    "dataset": "tiny-stories",
    "epochs": EPOCHS,
    }   
)

[34m[1mwandb[0m: Currently logged in as: [33mggg4[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [21]:
from tqdm.auto import tqdm
def save_imm(epoch, step, saved=[]):
    fname = f"step-{epoch}-{step}.bin"
    torch.save(model.state_dict(), f"step-{epoch}-{step}.bin")
    saved.append(fname)
    if len(saved) > 5:
        delete_me = saved.pop(0)
        Path(delete_me).unlink(missing_ok=True)

def epoch_step(epoch, opt):
    for i, batch in enumerate(bar := tqdm(train_dl)):
        for k in batch:
            batch[k] = batch[k].to(device=model.lm_head.weight.device)
        
        n_batch, n_seq = batch["input_ids"].shape
        if n_seq > tokenizer.model_max_length:
            assert tokenizer.padding_side == "right", "Left-pad truncation only supported[as model should not see >2k token anyway]"
            batch["input_ids"] = batch["input_ids"][:, -tokenizer.model_max_length]
            batch["labels"] = batch["labels"][:, -tokenizer.model_max_length]
            batch["attention_mask"] = batch["attention_mask"][:, -tokenizer.model_max_length]
            
        
        loss = model(**batch).loss
        loss.backward()
        opt.step()
        opt.zero_grad()
        bar.set_description(f'L:{loss.item():.4f}')
        wandb.log({"loss": loss.item()})
        if (i+1) % 100 == 0:
            save_imm(epoch, i+1)
        
    torch.save(model.state_dict(), f"epoch-{epoch}.bin")


In [22]:
opt = torch.optim.AdamW(model.parameters(), fused=True)


In [None]:
for e in range(EPOCHS):
    epoch_step(e+1, opt)

  0%|          | 0/169865 [00:00<?, ?it/s]

In [45]:
!free -h

               total        used        free      shared  buff/cache   available
Mem:            85Gi       1.5Gi        72Gi       8.0Mi        11Gi        83Gi
Swap:             0B          0B          0B


In [65]:
!nvidia-smi

Fri Jul  7 17:44:05 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.61.05    Driver Version: 520.61.05    CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-PCI...  On   | 00000000:05:00.0 Off |                    0 |
| N/A   30C    P0    34W / 250W |   5739MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

Parameter containing:
tensor([[ 8.3618e-03,  3.8330e-02, -5.9204e-03,  ...,  2.0752e-02,
          4.4861e-03,  1.2512e-02],
        [ 3.9978e-03,  2.1118e-02, -3.5645e-02,  ..., -1.6846e-02,
          5.0659e-03, -3.8818e-02],
        [-1.6928e-05, -1.2756e-02, -1.1536e-02,  ..., -1.6235e-02,
          4.8218e-03, -1.4099e-02],
        ...,
        [-9.8267e-03, -6.8665e-03,  1.0864e-02,  ..., -1.0864e-02,
         -2.4170e-02, -5.6076e-04],
        [-9.5749e-04,  7.3853e-03,  4.9438e-03,  ...,  1.2390e-02,
         -2.1606e-02, -9.2163e-03],
        [ 5.1758e-02,  2.1484e-02, -1.5381e-02,  ..., -2.4292e-02,
         -3.4912e-02,  3.0823e-03]], device='cuda:0', dtype=torch.bfloat16,
       requires_grad=True)