"""Smoke test: check CUDA availability with PyTorch and exercise ``Dataset.map``."""
import torch
from datasets import Dataset

# Fall back to CPU so the script still runs on machines without a GPU;
# with an unconditional torch.device("cuda"), the tensor allocation below
# raises before the availability check is ever printed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tenz = torch.tensor([1.0, 2.0], device=device)
print(torch.cuda.is_available())

dataset = Dataset.from_dict({"a": [0, 1, 2]})
# map() is called without batched=True, so the function receives one example
# at a time: example["a"] is a single value and "b" is that value doubled.
# Despite the variable name, no rows are duplicated -- a column "b" is added.
dataset_with_duplicates = dataset.map(lambda example: {"b": example["a"] * 2})
print(dataset_with_duplicates.shape)
# Bare expressions are only echoed in a REPL/notebook; print them explicitly
# so they are visible when run as a script.
print(len(dataset_with_duplicates))
print(dataset_with_duplicates[:])