# DemoEmotions/utils/CustomDataset.py
# Uploaded by karpurna2 -- commit 127e34a ("initial upload"), 1.19 kB.
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
    """Dataset that yields tokenized text, a label, and an image per item.

    Items are indexed over ``texts`` and ``labels``. The ``image`` attribute
    is never indexed: every item returns the same stored image object,
    optionally passed through ``transforms``.
    NOTE(review): presumably a single image is intentionally shared by all
    texts -- confirm against the caller.
    """

    def __init__(self, image, texts, labels, tokenizer, max_len, transforms=None):
        """Keep references to the data and the preprocessing objects.

        Args:
            image: Object returned (after optional transforms) with each item.
            texts: Indexable collection; its length defines the dataset size.
            labels: Indexable collection read at the same index as ``texts``.
            tokenizer: Object exposing an ``encode_plus`` method.
            max_len: Value forwarded to the tokenizer as ``max_length``.
            transforms: Optional callable applied to ``image`` on each access.
        """
        self.image, self.texts, self.labels = image, texts, labels
        self.tokenizer, self.max_len = tokenizer, max_len
        self.transforms = transforms

    def __len__(self):
        """Return the number of entries in ``texts``."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Build the sample at position ``idx``.

        Returns:
            A dict with ``input_ids`` and ``attention_mask`` as long tensors,
            ``labels`` as a float tensor, and ``images`` holding the
            (possibly transformed) image.
        """
        # Entries are coerced to str before tokenization.
        sample_text = str(self.texts[idx])
        sample_label = self.labels[idx]

        # Transforms run on every access, and only when the object is truthy.
        picture = self.transforms(self.image) if self.transforms else self.image

        # Second positional argument (the text pair) is explicitly None.
        encoding = self.tokenizer.encode_plus(
            sample_text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
        )
        token_ids, mask = encoding['input_ids'], encoding['attention_mask']

        sample = {}
        sample['input_ids'] = torch.tensor(token_ids, dtype=torch.long)
        sample['attention_mask'] = torch.tensor(mask, dtype=torch.long)
        sample['labels'] = torch.tensor(sample_label, dtype=torch.float)
        sample['images'] = picture
        return sample