# Data structure: example `data` dictionaries for different task types
# Example batch for task_type 'image_classification'.
# `bs` is not defined in this snippet and must come from the surrounding
# context -- presumably the batch size; confirm.
data = {
    # A single image input: float32 tensor of shape (bs, 3, 224, 224).
    'input_sample_list': [
        {
            'data': torch.rand(bs, 3, 224, 224, dtype=torch.float32),
            'invalid_mask': None,
            'modality': 'image',
            'data_type': 'input',
            # One dict for the whole batch, holding per-sample lists.
            'sample_info': {
                'id': list(range(bs)),
                'path': ['hah'] * bs,
            },
        },
    ],
    # No per-sample targets; targets are given as indices instead.
    'target_sample_list': [],
    # One integer in [0, 1000) per sample -- presumably a row index into the
    # shared target set named in 'target_set_list'; confirm.
    'target_idx_list': [torch.randint(0, 1000, (bs,))],
    'target_set_list': ['ImageNet22k'],
    'shared_target_sets': {
        # 1000 rows of 11 integer ids each, with an all-False invalid mask of
        # the same shape. Presumably tokenized class names; confirm.
        'ImageNet22k': [
            {
                'data': torch.randint(0, 49411, (1000, 11)),
                'invalid_mask': torch.zeros(1000, 11, dtype=torch.bool),
                'modality': 'text',
                'data_type': 'target',
                'sample_info': {
                    'distributed': True,
                    'total_num': 1000,
                },
            },
        ],
    },
    'task_info': {
        'task_name': 'imagenet',
        'task_type': 'image_classification',
        'dataset_name': 'ImageNet22k',
        'batchsize': None,
        'sampling_ratio': None,
    },
}
# Example batch for task_type 'image_caption'.
# `bs` is not defined in this snippet and must come from the surrounding
# context -- presumably the batch size; confirm.
data = {
    'input_sample_list': [
        # Image input: float32 tensor of shape (bs, 3, 224, 224).
        {
            'data': torch.rand(bs, 3, 224, 224, dtype=torch.float32),
            'invalid_mask': None,
            'modality': 'image',
            'data_type': 'input',
            # Here 'sample_info' is a list with one dict per sample.
            # The loop index is used as the sample id; the bare name `id`
            # would store the builtin function instead of an identifier.
            'sample_info': [
                {
                    'id': i,
                    'path': 'hahah',
                    'bs': bs,
                }
                for i in range(bs)
            ],
        },
        # Text input: integer ids of shape (bs, 62), i.e. 31 * 2 per sample,
        # with an all-False invalid mask of the same shape.
        {
            'data': torch.randint(0, 49411, (bs, 31 * 2)),
            'invalid_mask': torch.zeros(bs, 31 * 2, dtype=torch.bool),
            'modality': 'text',
            'data_type': 'input',
            # 'pe_index' is 0..30 followed by 0..30 again (length 62, matching
            # the text length). Presumably position-embedding indices for the
            # two 31-token halves; confirm.
            'sample_info': [
                {
                    'pe_index': torch.cat(
                        [torch.arange(31), torch.arange(31)],
                        dim=0,
                    ),
                }
                for _ in range(bs)
            ],
        },
    ],
    'target_sample_list': [],
    # Integer ids in [0, 49411), shape (bs, 31).
    'target_idx_list': [torch.randint(0, 49411, (bs, 31))],
    'target_set_list': ['Vocab_Word'],
    'shared_target_sets': {
        # 49411 rows of 2 integer ids each; no invalid mask.
        'Vocab_Word': [
            {
                'data': torch.randint(0, 49411, (49411, 2)),
                'invalid_mask': None,
                'modality': 'text',
                'data_type': 'target',
                'sample_info': {
                    'distributed': True,
                    'total_num': 49411,
                },
            },
        ],
    },
    'task_info': {
        'task_name': 'mscoco_caption',
        'task_type': 'image_caption',
        'dataset_name': 'MSCOCO',
        'batchsize': None,
        'sampling_ratio': None,
    },
}
# Example batch for task_type 'text_mlm'.
# `bs` is not defined in this snippet and must come from the surrounding
# context -- presumably the batch size; confirm.
data = {
    # A single text input: integer ids of shape (bs, 128) with an all-False
    # invalid mask of the same shape.
    'input_sample_list': [
        {
            'data': torch.randint(0, 49411, (bs, 128)),
            'invalid_mask': torch.zeros(bs, 128, dtype=torch.bool),
            'modality': 'text',
            'data_type': 'input',
            'sample_info': {
                'seq_length': 128,
            },
        },
    ],
    'target_sample_list': [],
    # Integer ids of shape (bs, 128). The original note here reads
    # "most are -1" -- presumably describing real batches, since this random
    # example only produces values in [0, 49411); confirm.
    'target_idx_list': [torch.randint(0, 49411, (bs, 128))],
    'target_set_list': ['Vocab_Word'],
    'shared_target_sets': {
        # 49411 rows of 2 integer ids each; no invalid mask.
        'Vocab_Word': [
            {
                'data': torch.randint(0, 49411, (49411, 2)),
                'invalid_mask': None,
                'modality': 'text',
                'data_type': 'target',
                'sample_info': {
                    'distributed': True,
                    'total_num': 49411,
                },
            },
        ],
    },
    'task_info': {
        'task_name': 'bookswiki_pretrain',
        'task_type': 'text_mlm',
        'dataset_name': 'BooksWiki',
        'batchsize': None,
        'sampling_ratio': None,
    },
}
# Example batch for task_type 'image_retrieval'.
# `bs` is not defined in this snippet and must come from the surrounding
# context -- presumably the batch size; confirm.
# NOTE(review): none of the sample dicts below carries a 'data_type' key --
# confirm whether that is intentional for this task.
data = {
    # A single image input: float32 tensor of shape (bs, 3, 224, 224).
    'input_sample_list': [
        {
            'data': torch.rand(bs, 3, 224, 224, dtype=torch.float32),
            'invalid_mask': None,
            'modality': 'image',
            'sample_info': {
                'id': list(range(bs)),
                'path': ['hah'] * bs,
            },
        },
    ],
    # Targets are given directly as samples: integer ids of shape (bs, 30)
    # with an all-False invalid mask of the same shape.
    'target_sample_list': [
        {
            'data': torch.randint(0, 49411, (bs, 30)),
            'invalid_mask': torch.zeros(bs, 30, dtype=torch.bool),
            'modality': 'text',
            'sample_info': {},
        },
    ],
    # No index-based targets in this example.
    'target_idx_list': [],
    'target_set_list': [],
    # NOTE(review): an 'ImageNet22k' set is listed here although
    # 'target_set_list' is empty -- possibly left over from another example;
    # confirm.
    'shared_target_sets': {
        'ImageNet22k': [
            {
                'data': torch.randint(0, 49411, (1000, 11)),
                'invalid_mask': torch.zeros(1000, 11, dtype=torch.bool),
                'modality': 'text',
                'sample_info': {
                    'distributed': True,
                    'total_num': 1000,
                },
            },
        ],
    },
    'task_info': {
        'task_name': 'mscoco_retrieve',
        'task_type': 'image_retrieval',
        'dataset_name': 'MSCOCO',
        'batchsize': None,
        'sampling_ratio': None,
    },
}