Spaces:
Build error
Build error
File size: 1,389 Bytes
ff43e05 c731c61 ff43e05 c731c61 ff43e05 c731c61 ff43e05 c731c61 ff43e05 c731c61 ff43e05 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import torch
def compute_args(args):
    """Derive dataset- and task-dependent settings and attach them to ``args``.

    Reads ``args.dataset`` (defaulting to ``"MOSEI"`` when the attribute is
    absent, for configs from a previous version), ``args.task`` and, for
    MOSEI sentiment only, ``args.task_binary`` (treated as ``False`` when
    absent). Sets these attributes on ``args``:

    * ``dataloader`` -- dataset class name: ``"Mosei_Dataset"`` or
      ``"Meld_Dataset"``.
    * ``loss_fn`` -- a ``torch.nn`` loss built with ``reduction="sum"``:
      ``BCEWithLogitsLoss`` for MOSEI emotion, ``CrossEntropyLoss`` otherwise.
    * ``ans_size`` -- number of output classes (MOSEI sentiment: 7, or 2 when
      ``task_binary`` is truthy; MOSEI emotion: 6; MELD emotion: 7;
      MELD sentiment: 3).
    * ``pred_func`` -- ``"multi_label"`` for MOSEI emotion, ``"amax"``
      otherwise.

    Args:
        args: attribute container (e.g. an ``argparse.Namespace``); it is
            mutated in place.

    Returns:
        The same ``args`` object.

    Raises:
        ValueError: if ``args.dataset`` is not ``"MOSEI"``/``"MELD"`` or
            ``args.task`` is not ``"sentiment"``/``"emotion"``.
    """
    # Fix for configs saved by a previous version that had no "dataset" field.
    if not hasattr(args, "dataset"):
        args.dataset = "MOSEI"

    if args.dataset == "MOSEI":
        args.dataloader = "Mosei_Dataset"
        if args.task == "sentiment":
            args.loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
            # Older configs may lack task_binary; treat that as non-binary.
            args.ans_size = 2 if getattr(args, "task_binary", False) else 7
            args.pred_func = "amax"
        elif args.task == "emotion":
            # BCEWithLogitsLoss scores each emotion label independently,
            # matching the multi-label prediction function.
            args.loss_fn = torch.nn.BCEWithLogitsLoss(reduction="sum")
            args.ans_size = 6
            args.pred_func = "multi_label"
        else:
            raise ValueError(
                f"Unsupported task {args.task!r} for dataset {args.dataset!r}"
            )
    elif args.dataset == "MELD":
        args.dataloader = "Meld_Dataset"
        args.loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
        args.pred_func = "amax"
        if args.task == "emotion":
            args.ans_size = 7
        elif args.task == "sentiment":
            args.ans_size = 3
        else:
            raise ValueError(
                f"Unsupported task {args.task!r} for dataset {args.dataset!r}"
            )
    else:
        raise ValueError(f"Unsupported dataset: {args.dataset!r}")

    return args
|