import torch
import tiktoken
from torch.utils.data import Dataset, DataLoader


class GPTDatasetV1(Dataset):
    def __init__(self, txt, tokenizer, max_length, stride):
        self.input_ids = []
        self.target_ids = []

        # Tokenize the entire text
        token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})

        # Use a sliding window to chunk the text into overlapping sequences
        # of max_length; each target is the input shifted right by one token
        for i in range(0, len(token_ids) - max_length, stride):
            input_chunk = token_ids[i:i + max_length]
            target_chunk = token_ids[i + 1:i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]


def create_dataloader_v1(txt, batch_size=4, max_length=256,
                         stride=128, shuffle=True, drop_last=True,
                         num_workers=0):
    # Initialize the GPT-2 BPE tokenizer
    tokenizer = tiktoken.get_encoding("gpt2")

    # Build the dataset of (input, target) token windows
    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)

    # drop_last=True discards a final batch shorter than batch_size,
    # keeping batch shapes consistent during training
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers
    )

    return dataloader
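

# A minimal usage sketch, not part of the listing above: it assumes a UTF-8
# plain-text file at the hypothetical path "sample.txt"; any text works.
# Small max_length/stride values make the batch structure easy to inspect;
# stride == max_length yields non-overlapping windows.
if __name__ == "__main__":
    with open("sample.txt", "r", encoding="utf-8") as f:
        raw_text = f.read()

    dataloader = create_dataloader_v1(
        raw_text, batch_size=8, max_length=4, stride=4, shuffle=False
    )

    inputs, targets = next(iter(dataloader))
    print("Inputs shape:", inputs.shape)    # torch.Size([8, 4])
    print("Targets shape:", targets.shape)  # torch.Size([8, 4])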