import json
from typing import Optional

import datasets
import evaluate
import pandas as pd
import torch
from transformers import AutoModel, DistilBertForSequenceClassification

SEP_TOKEN = '[SEP]'
# Map NLI-style labels onto regression targets in [0, 1].
LABEL2NUM = {'entailment': 1.0, 'neutral': 0.5, 'contradiction': 0.0}

def format_dataset(arr):
    """Join each sentence pair with [SEP], map labels to scores, and shuffle."""
    text = [el['sentence1'] + SEP_TOKEN + el['sentence2'] for el in arr]
    label = [LABEL2NUM[el['label']] for el in arr]
    new_df = pd.DataFrame({'text': text, 'label': label})
    return new_df.sample(frac=1, random_state=42).reset_index(drop=True)

def load_dataset(path):
    """Read sentence pairs from a JSONL file and build train/test splits."""
    train_array = []
    with open(path) as f:
        for line in f:
            if line.strip():
                train_array.append(json.loads(line))
    df = format_dataset(train_array)

    # Hold out the first 512 shuffled rows as the test split.
    df_train = df.iloc[512:, :]
    df_test = df.iloc[:512, :]
    print(df_train[:10])
    print(df_test[:10])

    factual_consistency_dataset = datasets.DatasetDict()
    factual_consistency_dataset["train"] = datasets.Dataset.from_pandas(
        df_train[["text", "label"]])
    factual_consistency_dataset["test"] = datasets.Dataset.from_pandas(
        df_test[["text", "label"]])

    return factual_consistency_dataset
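
# Illustrative JSONL record that load_dataset expects, one object per line (the
# field names mirror format_dataset; the values here are made up):
#
#   {"sentence1": "premise text", "sentence2": "hypothesis text", "label": "entailment"}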

class ConsistentSentenceRegressor(DistilBertForSequenceClassification):
    """DistilBERT regressor that scores a sentence pair for factual consistency."""

    def __init__(self, freeze_bert=True):
        base_model = AutoModel.from_pretrained(
            'line-corporation/line-distilbert-base-japanese')

        config = base_model.config
        config.problem_type = "regression"
        config.num_labels = 1
        super().__init__(config=config)

        # Replace the randomly initialised backbone created by the parent
        # constructor with the pretrained LINE DistilBERT weights.
        self.distilbert = base_model

        # Single-output regression head.
        self.classifier = torch.nn.Linear(config.dim, config.num_labels)

        self.loss_fn = torch.nn.MSELoss()

        # Optionally freeze the backbone so only the heads are trained.
        if freeze_bert:
            for param in self.distilbert.parameters():
                param.requires_grad = False

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Collapse the trailing singleton dimension so logits have shape
        # (batch_size,). This assumes return_dict resolves to True (the
        # transformers default), so outputs is a SequenceClassifierOutput.
        logits = outputs.logits.squeeze(-1)
        outputs.logits = logits
        if labels is not None:
            # Recompute the loss against the squeezed logits, casting labels to
            # float so MSELoss does not fail on integer or double targets.
            loss = self.loss_fn(logits, labels.float())
            outputs.loss = loss

        return outputs
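
# Minimal usage sketch for the regressor (shapes and values are illustrative;
# assumes the LINE DistilBERT weights are reachable on the Hub):
#
#   model = ConsistentSentenceRegressor(freeze_bert=True)
#   batch = {
#       "input_ids": torch.ones(2, 16, dtype=torch.long),
#       "attention_mask": torch.ones(2, 16, dtype=torch.long),
#       "labels": torch.tensor([1.0, 0.5]),
#   }
#   out = model(**batch)  # out.logits has shape (2,); out.loss is the MSE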

def get_metrics():
    """Build a compute_metrics callback that reports MSE to the Trainer."""
    metric = evaluate.load("mse")

    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        print(predictions.shape)
        print(labels.shape)
        return metric.compute(predictions=predictions, references=labels)

    return compute_metrics
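
# A minimal end-to-end sketch wiring the pieces above together. It assumes a
# local `train.jsonl` in the schema shown earlier; the file name, tokenizer
# settings, and hyperparameters are illustrative assumptions, not values taken
# from this project. trust_remote_code is assumed to be needed for LINE's
# custom Japanese tokenizer.
if __name__ == "__main__":
    from transformers import AutoTokenizer, Trainer, TrainingArguments

    tokenizer = AutoTokenizer.from_pretrained(
        'line-corporation/line-distilbert-base-japanese', trust_remote_code=True)

    dataset = load_dataset('train.jsonl')
    dataset = dataset.map(
        lambda batch: tokenizer(
            batch['text'], truncation=True, padding='max_length', max_length=128),
        batched=True)

    trainer = Trainer(
        model=ConsistentSentenceRegressor(freeze_bert=True),
        args=TrainingArguments(output_dir='out', num_train_epochs=1,
                               per_device_train_batch_size=16),
        train_dataset=dataset['train'],
        eval_dataset=dataset['test'],
        compute_metrics=get_metrics(),
    )
    trainer.train()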