File size: 3,669 Bytes
26d475a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cc51df
 
26d475a
 
 
 
 
9cc51df
26d475a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eab759c
9cc51df
 
 
 
 
26d475a
9cc51df
26d475a
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import json
import pandas as pd
import datasets
import numpy as np
import evaluate
import torch
from transformers import AutoModel, DistilBertForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput
from typing import Optional

# Separator placed between the two sentences of a pair in the model input text.
SEP_TOKEN = '[SEP]'
# Regression targets for the NLI labels: entailment -> 1, contradiction -> 0,
# neutral halfway between them.
LABEL2NUM = {'entailment': 1, 'neutral': 0.5, 'contradiction': 0}


def format_dataset(arr, *, sep_token=SEP_TOKEN, random_state=42):
    """Turn sentence-pair records into a shuffled ``text``/``label`` DataFrame.

    Args:
        arr: Iterable of mappings with ``sentence1``, ``sentence2`` and
            ``label`` keys; each ``label`` must be a key of ``LABEL2NUM``.
            Generators are accepted (the input is materialized once).
        sep_token: String inserted between ``sentence1`` and ``sentence2``.
        random_state: Seed for the row shuffle so the order is reproducible.

    Returns:
        A ``pandas.DataFrame`` with columns ``text`` (the joined pair) and
        ``label`` (the numeric target from ``LABEL2NUM``), rows shuffled and
        re-indexed from 0.

    Raises:
        KeyError: If a record lacks one of the keys or has an unknown label.
    """
    # Materialize once: the records are walked twice below, which would
    # silently exhaust a generator after the first pass.
    records = list(arr)
    text = [el['sentence1'] + sep_token + el['sentence2'] for el in records]
    label = [LABEL2NUM[el['label']] for el in records]
    new_df = pd.DataFrame({'text': text, 'label': label})
    # frac=1 samples every row exactly once, i.e. a seeded shuffle.
    return new_df.sample(frac=1, random_state=random_state).reset_index(drop=True)

# Load dataset
def load_dataset(path, test_size=512):
    """Read a JSON Lines sentence-pair file and split it into train/test sets.

    Each non-blank line of *path* must be a JSON object accepted by
    ``format_dataset`` (keys ``sentence1``, ``sentence2``, ``label``). After
    ``format_dataset`` shuffles the rows, the first *test_size* rows become
    the "test" split and the remainder the "train" split. The first ten rows
    of each split are printed as a quick sanity check.

    Args:
        path: Path to the JSON Lines file; it is decoded as UTF-8.
        test_size: Number of rows held out for the "test" split.

    Returns:
        A ``datasets.DatasetDict`` with "train" and "test" splits, each
        containing the "text" and "label" columns.
    """
    # JSON text is UTF-8 and the data may be Japanese; do not rely on the
    # platform's default encoding.
    with open(path, encoding="utf-8") as f:
        # Skip blank / whitespace-only lines: "\n" is truthy, and
        # json.loads would raise on it.
        records = [json.loads(line) for line in f if line.strip()]
    df = format_dataset(records)

    # Split off a small held-out set; we do not need much test data.
    df_train = df.iloc[test_size:, :]
    df_test = df.iloc[:test_size, :]
    print(df_train[:10])
    print(df_test[:10])

    factual_consistency_dataset = datasets.DatasetDict()
    # preserve_index=False keeps the pandas index from becoming a column.
    factual_consistency_dataset["train"] = datasets.Dataset.from_pandas(
        df_train[["text", "label"]], preserve_index=False)
    factual_consistency_dataset["test"] = datasets.Dataset.from_pandas(
        df_test[["text", "label"]], preserve_index=False)

    return factual_consistency_dataset


class ConsistentSentenceRegressor(DistilBertForSequenceClassification):
    """DistilBERT with a single-output regression head trained with MSE.

    The encoder weights come from a pretrained checkpoint loaded through
    ``AutoModel``. The head inherited from
    ``DistilBertForSequenceClassification`` is configured for regression
    (``num_labels == 1``), so the model predicts one scalar per input.
    """

    def __init__(self, freeze_bert=True,
                 model_name='line-corporation/line-distilbert-base-japanese'):
        """Load the pretrained encoder and set up the regression head.

        Args:
            freeze_bert: If True (the default), set ``requires_grad=False`` on
                every parameter of ``self.distilbert`` so only the layers
                outside the encoder are trained.
            model_name: Checkpoint name or path passed to
                ``AutoModel.from_pretrained``.
        """
        base_model = AutoModel.from_pretrained(model_name)

        # Reuse the checkpoint's config, switched to single-output regression.
        # This mutates base_model.config too (it is the same object).
        config = base_model.config
        config.problem_type = "regression"
        config.num_labels = 1
        super().__init__(config=config)

        # NOTE(review): the parent constructor presumably builds its own
        # encoder from `config`; replace it with the pretrained weights.
        self.distilbert = base_model

        # Single-neuron linear layer producing the regression output.
        self.classifier = torch.nn.Linear(config.dim, config.num_labels)

        self.loss_fn = torch.nn.MSELoss()

        if freeze_bert:
            for param in self.distilbert.parameters():
                param.requires_grad = False

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the model and, when ``labels`` is given, compute the MSE loss.

        Arguments mirror ``DistilBertForSequenceClassification.forward``.
        ``labels`` holds one (float) regression target per example; it is
        cast to the logits' dtype before the loss is computed.

        Returns:
            A ``SequenceClassifierOutput`` whose ``logits`` have the trailing
            size-1 dimension removed, or the equivalent tuple
            (``to_tuple()``) when ``return_dict`` resolves to False.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Labels are withheld from the parent: the loss is computed below, so
        # letting the parent compute its own would be redundant work.
        # return_dict=True lets the fields be read by name regardless of
        # what the caller asked for.
        base_outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            labels=None,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # Drop the trailing num_labels (== 1) dimension to match the targets.
        logits = base_outputs.logits.squeeze(-1)

        loss = None
        if labels is not None:
            # Match dtype and shape so MSELoss neither fails on e.g. float64
            # targets nor silently broadcasts (N,) against (N, 1).
            targets = labels.to(logits.dtype).reshape_as(logits)
            loss = self.loss_fn(logits, targets)

        outputs = SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=base_outputs.hidden_states,
            attentions=base_outputs.attentions,
        )
        if not return_dict:
            return outputs.to_tuple()
        return outputs


# Set up evaluation metrics

def get_metrics(metric_name="mse"):
    """Build a ``compute_metrics`` callback for evaluating the regressor.

    Args:
        metric_name: Metric name passed to ``evaluate.load``; defaults to
            "mse".

    Returns:
        A function that takes a ``(predictions, labels)`` pair, prints the
        raw shapes of both, and returns the dict produced by the loaded
        metric's ``compute``.
    """
    # Load once so every evaluation call reuses the same metric object.
    metric = evaluate.load(metric_name)

    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        # NOTE: assumes that when the model emits extra outputs the
        # predictions arrive as a tuple whose first element is the
        # regression output.
        if isinstance(predictions, tuple):
            predictions = predictions[0]
        print(np.shape(predictions))
        print(np.shape(labels))
        # The model has a single output, so flatten e.g. (N, 1) to (N,) to
        # line predictions and references up element-wise.
        predictions = np.asarray(predictions).reshape(-1)
        labels = np.asarray(labels).reshape(-1)
        return metric.compute(predictions=predictions, references=labels)

    return compute_metrics