from datasets import load_dataset
from transformers import BertTokenizer
def load_and_tokenize_data(config):
    """
    Load and tokenize data based on the provided configuration.

    Args:
        config (dict): Configuration dictionary containing dataset and tokenizer details.

    Returns:
        tuple: A tuple containing the tokenized train and test datasets.
    """
    # Load the dataset and carve out a 20% test split
    dataset = load_dataset(config['dataset']['name'], split=config['dataset']['split'])
    dataset = dataset.train_test_split(test_size=0.2)
    train_dataset = dataset['train']
    test_dataset = dataset['test']
    # Initialize the tokenizer; fall back to bert-base-uncased when the config
    # does not specify one, so the example config below still works
    tokenizer_name = config.get('tokenizer', {}).get('name', 'bert-base-uncased')
    tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
    # Define the tokenization function
    def tokenize_function(examples):
        return tokenizer(examples['text'], padding='max_length', truncation=True)

    # Apply tokenization to the train and test datasets
    train_dataset = train_dataset.map(tokenize_function, batched=True)
    test_dataset = test_dataset.map(tokenize_function, batched=True)

    # Set the format to PyTorch tensors
    train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
    test_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

    return train_dataset, test_dataset

# Example usage
if __name__ == "__main__":
    config = {
        'dataset': {
            'name': 'imdb',
            'split': 'train[:10%]'
        }
    }
    train_dataset, test_dataset = load_and_tokenize_data(config)
    print("Train dataset and Test dataset have been loaded and tokenized successfully.")