File size: 405 Bytes
0416ac9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Smoke test for the local PyTorch + Hugging Face `datasets` setup: allocate a
# small tensor on the GPU (falling back to the CPU when CUDA is unavailable)
# and apply a per-example `Dataset.map` transform, printing the results.
import torch
from datasets import Dataset

# Check (and report) CUDA availability *before* touching the device, so a
# missing GPU shows up in the output instead of only as an allocation error.
cuda_available = torch.cuda.is_available()
print(cuda_available)

# Fall back to the CPU rather than crashing on machines without a usable GPU.
device = torch.device("cuda" if cuda_available else "cpu")

# Allocate directly on the target device (same result as building the tensor
# on the CPU and calling `.to(device)`).
tenz = torch.tensor([1.0, 2.0], device=device)

dataset = Dataset.from_dict({"a": [0, 1, 2]})

# `map` is called without `batched=True`, so the function receives one example
# at a time; `example["a"] * 2` therefore doubles the integer value, adding a
# column "b" — no rows are duplicated despite the variable name.
dataset_with_duplicates = dataset.map(lambda example: {"b": example["a"] * 2})
print(dataset_with_duplicates.shape)

# Bare expressions are no-ops in a script (they only echo in a REPL/notebook),
# so print the values explicitly.
print(len(dataset_with_duplicates))
print(dataset_with_duplicates[:])