from typing import Literal, Optional

import torch
from datasets import load_dataset

from encoder import encode, tokens

DatasetType = Optional[str]

# Cache the raw text of each split so we only hit the Hub once per split.
_datasets: dict[str, DatasetType] = {
    'train': None,
    'validation': None,
    'test': None,
}


def make_dataset(split: Literal['train', 'validation', 'test'] = 'train') -> str:
    """Lazily load and cache the raw text for one tiny_shakespeare split."""
    if _datasets[split] is None:
        ds = load_dataset(
            "karpathy/tiny_shakespeare", split=split, trust_remote_code=True)
        # Each split is a single row whose 'text' field holds the full corpus,
        # so indexing row 0 directly avoids materializing the whole dataset.
        _datasets[split] = str(ds[0]['text'])
    return str(_datasets[split])


class Batcher:
    def __init__(self, device: Literal['cuda', 'cpu'],
                 batch_size: int, block_size: int):
        self.device = device
        self.batch_size = batch_size
        self.block_size = block_size
        # Encode each split once up front and keep it as a 1-D tensor of
        # token ids. make_dataset is defined above, so no self-import needed.
        self.train_data = torch.tensor(
            encode(make_dataset('train')), dtype=torch.long)
        self.val_data = torch.tensor(
            encode(make_dataset('validation')), dtype=torch.long)
        self.vocab = tokens

    def get_batch(self, split: str = 'val'):
        data = self.train_data if split == 'train' else self.val_data
        # Sample random starting offsets on the CPU, where the data tensors
        # live; only the assembled batches are moved to the target device.
        random_indexes = torch.randint(
            len(data) - self.block_size, (self.batch_size,))
        # Contexts are block_size-token windows; answers are the same
        # windows shifted one token to the right.
        context_stack = torch.stack(
            [data[i:i + self.block_size] for i in random_indexes]).to(self.device)
        answer_stack = torch.stack(
            [data[i + 1:i + self.block_size + 1] for i in random_indexes]).to(self.device)
        return context_stack, answer_stack
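

# A minimal usage sketch (not part of the original module): it assumes
# encoder.encode maps a string to a list of token ids and encoder.tokens
# is the vocabulary, as the imports above suggest. Run this file directly
# to sanity-check batch shapes and the context/target alignment.
if __name__ == "__main__":
    batcher = Batcher(device='cpu', batch_size=4, block_size=8)
    xb, yb = batcher.get_batch('train')
    # Both stacks should be (batch_size, block_size).
    print(xb.shape, yb.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
    # Targets are the contexts shifted right by one token.
    assert torch.equal(xb[:, 1:], yb[:, :-1])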