File size: 1,389 Bytes
ff43e05
 
 
 
 
c731c61
 
ff43e05
c731c61
 
 
 
ff43e05
 
c731c61
 
 
 
 
 
ff43e05
 
c731c61
 
 
 
 
 
 
 
 
 
ff43e05
c731c61
 
 
 
 
 
ff43e05
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch


def compute_args(args):
    """Attach dataset/task-dependent settings to ``args`` and return it.

    Reads ``args.dataset`` (defaulting it to ``"MOSEI"`` when absent),
    ``args.task`` and, for MOSEI sentiment, the optional ``args.task_binary``
    flag. Sets on ``args``:

    * ``dataloader`` -- name of the dataset class (string).
    * ``loss_fn``    -- a summed torch loss module.
    * ``ans_size``   -- number of output classes/labels.
    * ``pred_func``  -- name of the prediction function (``"amax"`` or
      ``"multi_label"``).

    Datasets other than MOSEI/MELD leave these attributes unset, and a task
    other than "sentiment"/"emotion" leaves ``loss_fn`` (MOSEI) and
    ``ans_size`` unset, exactly as before.

    Args:
        args: a mutable namespace-like object (e.g. ``argparse.Namespace``).

    Returns:
        The same ``args`` object, mutated in place.
    """
    # DataLoader
    if not hasattr(args, "dataset"):  # fix for previous version
        args.dataset = "MOSEI"

    if args.dataset == "MOSEI":
        args.dataloader = "Mosei_Dataset"
        args.pred_func = "amax"
        if args.task == "sentiment":
            args.loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
            # Older configs may predate the task_binary flag; treat it as off.
            args.ans_size = 2 if getattr(args, "task_binary", False) else 7
        elif args.task == "emotion":
            # Emotion on MOSEI is multi-label, hence BCE on logits.
            args.loss_fn = torch.nn.BCEWithLogitsLoss(reduction="sum")
            args.ans_size = 6
            args.pred_func = "multi_label"
    elif args.dataset == "MELD":
        args.dataloader = "Meld_Dataset"
        # MELD uses the same single-label loss for both tasks.
        args.loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
        args.pred_func = "amax"
        if args.task == "emotion":
            args.ans_size = 7
        elif args.task == "sentiment":
            args.ans_size = 3

    return args