# GenAI_project/data/dataset.py
from datasets import load_dataset
from transformers import BertTokenizer


def load_and_tokenize_data(config):
    """
    Load and tokenize data based on the provided configuration.

    Args:
        config (dict): Configuration dictionary containing dataset details
            (name and split) and, optionally, the tokenizer name.

    Returns:
        tuple: A tuple containing the tokenized train and test datasets.
    """
    # Load the dataset and carve out a held-out test split
    # (seeded so the split is reproducible across runs)
    dataset = load_dataset(config['dataset']['name'], split=config['dataset']['split'])
    dataset = dataset.train_test_split(test_size=0.2, seed=42)
    train_dataset = dataset['train']
    test_dataset = dataset['test']

    # Initialize the tokenizer (falls back to bert-base-uncased
    # when the config provides no tokenizer entry)
    tokenizer_name = config.get('tokenizer', {}).get('name', 'bert-base-uncased')
    tokenizer = BertTokenizer.from_pretrained(tokenizer_name)

    # Define the tokenization function
    def tokenize_function(examples):
        return tokenizer(examples['text'], padding='max_length', truncation=True)

    # Apply tokenization to the train and test datasets
    train_dataset = train_dataset.map(tokenize_function, batched=True)
    test_dataset = test_dataset.map(tokenize_function, batched=True)

    # Set the format to PyTorch tensors
    train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
    test_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

    return train_dataset, test_dataset

# Example usage
if __name__ == "__main__":
    config = {
        'dataset': {
            'name': 'imdb',
            'split': 'train[:10%]'
        },
        'tokenizer': {
            'name': 'bert-base-uncased'
        }
    }
    train_dataset, test_dataset = load_and_tokenize_data(config)
    print("Train dataset and test dataset have been loaded and tokenized successfully.")