# OpenSLU / model/decoder/decoder_utils.py
# Uploaded by LightChen2333 ("Upload 34 files"), revision 37b9e99, 6.44 kB.
from typing import List, Union

import torch
from torch import Tensor

from common import utils
from common.utils import OutputData, InputData
def argmax_for_seq_len(inputs, seq_lens, padding_value=-100):
    """Arg-max over the last dimension, computed only on the valid positions.

    Args:
        inputs (Tensor): padded scores. Presumably (batch, seq, num_labels) — TODO confirm
            against ``utils.pack_sequence``.
        seq_lens: per-sample valid lengths, forwarded to ``utils.pack_sequence`` /
            ``utils.unpack_sequence``.
        padding_value (int, optional): value written into padded positions when the
            result is re-padded. Defaults to -100.

    Returns:
        Tensor: the arg-max label index for each position, re-padded by
        ``utils.unpack_sequence`` with the trailing singleton dimension removed.
    """
    packed = utils.pack_sequence(inputs, seq_lens)
    # keepdim=True keeps a trailing axis so unpack_sequence receives the same rank as it
    # would for the scores; it is squeezed away again after re-padding.
    best_labels = packed.argmax(dim=-1, keepdim=True)
    repadded = utils.unpack_sequence(best_labels, seq_lens, padding_value)
    return repadded.squeeze(-1)
def _group_indices_by_row(indices: Tensor, num_rows: int) -> List[List[int]]:
    """Collect the column of every ``[row, col]`` index pair into one list per row.

    Args:
        indices (Tensor): index pairs such as those returned by ``Tensor.nonzero()`` on a
            2-D tensor.
        num_rows (int): number of output lists; rows that have no pair get ``[]``.

    Returns:
        List[List[int]]: ``num_rows`` lists of column indices, in the order the pairs appear.
    """
    grouped = [[] for _ in range(num_rows)]
    for item in indices.detach().cpu().tolist():
        grouped[item[0]].append(item[1])
    return grouped


def _majority_vote(labels: List[int]) -> int:
    """Return the most frequent label; ties go to the label that appears first.

    Raises:
        ValueError: if ``labels`` is empty.
    """
    from collections import Counter

    # Counter keys keep first-seen order and max() returns the first maximal key, which
    # gives the same tie-breaking as scanning the list with list.count (but in O(n)).
    counts = Counter(labels)
    return max(counts, key=counts.__getitem__)


def decode(output: OutputData,
           target: InputData = None,
           pred_type="slot",
           multi_threshold=0.5,
           ignore_index=-100,
           return_list=True,
           return_sentence_level=True,
           use_multi=False,
           use_crf=False,
           CRF=None) -> Union[List, Tensor]:
    """ decode output logits

    Args:
        output (OutputData): output logits data. ``slot_ids`` is read when
            ``pred_type == "slot"``, ``intent_ids`` otherwise.
        target (InputData, optional): input data. ``attention_mask`` is needed for CRF slot
            decoding and for sentence-level single-intent voting; ``seq_lens`` is needed for
            token-level multi-intent decoding. Plain argmax decoding and sentence-level
            multi-intent decoding work with ``None``. Defaults to None.
        pred_type (str, optional): prediction type in ["slot", "intent", "token-level-intent"].
            Defaults to "slot".
        multi_threshold (float, optional): sigmoid threshold for multi-intent decoding.
            Defaults to 0.5.
        ignore_index (int, optional): padding value of the per-token multi-intent output
            (``pred_type="token-level-intent"``, ``use_multi=True``,
            ``return_sentence_level=False``). Defaults to -100.
        return_list (bool, optional): if True return Python lists, else torch Tensors.
            Defaults to True.
        return_sentence_level (bool, optional): for "token-level-intent", if True aggregate
            token predictions into sentence-level intents, else return per-token predictions.
            Defaults to True.
        use_multi (bool, optional): whether to decode to multi intent. Defaults to False.
        use_crf (bool, optional): whether to decode slots with ``CRF.decode``. Defaults to False.
        CRF (CRF, optional): object providing ``decode(inputs, mask=...)``. Defaults to None.

    Returns:
        List or Tensor: decoded sequence ids. Multi-intent modes return, per sentence, the
        list of selected intent indices (or the raw ``nonzero()`` index tensor when
        ``return_list`` is False).

    Raises:
        NotImplementedError: for multi-slot decoding or an unsupported ``pred_type``.
    """
    inputs = output.slot_ids if pred_type == "slot" else output.intent_ids

    if pred_type == "slot":
        if use_multi:
            raise NotImplementedError("Multi-slot prediction is not supported.")
        if use_crf:
            res = CRF.decode(inputs, mask=target.attention_mask)
        else:
            res = torch.argmax(inputs, dim=-1)
    elif pred_type == "intent":
        if not use_multi:
            res = torch.argmax(inputs, dim=-1)
        else:
            # (row, intent) index pairs whose sigmoid probability exceeds the threshold.
            res = (torch.sigmoid(inputs) > multi_threshold).nonzero()
            if not return_list:
                return res
            # Only the batch size is needed; fall back to the logits when no target is given.
            num_rows = len(target.seq_lens) if target is not None else inputs.shape[0]
            return _group_indices_by_row(res, num_rows)
    elif pred_type == "token-level-intent":
        if not use_multi:
            res = torch.argmax(inputs, dim=-1)
            if not return_sentence_level:
                return res
            if return_list:
                token_labels = res.detach().cpu().tolist()
                mask_rows = target.attention_mask.tolist()
                votes = []
                for row_labels, row_mask in zip(token_labels, mask_rows):
                    # Count only the leading run of mask == 1 positions (this matches the
                    # mask only for right-padded batches).
                    valid = []
                    for label, flag in zip(row_labels, row_mask):
                        if flag != 1:
                            break
                        valid.append(label)
                    votes.append(_majority_vote(valid))
                return votes
            # NOTE(review): with return_list=False execution falls through and returns the
            # per-token argmax tensor, i.e. return_sentence_level is ignored. Kept unchanged
            # for existing callers — confirm whether that is intended.
        else:
            seq_lens = target.seq_lens
            if not return_sentence_level:
                token_res = torch.cat([
                    torch.sigmoid(inputs[i, 0:seq_lens[i], :]) > multi_threshold
                    for i in range(len(seq_lens))],
                    dim=0)
                return utils.unpack_sequence(token_res, seq_lens, padding_value=ignore_index)
            # For each sentence: how many of its valid tokens predict each intent.
            intent_index_sum = torch.cat([
                torch.sum(torch.sigmoid(inputs[i, 0:seq_lens[i], :]) > multi_threshold, dim=0).unsqueeze(0)
                for i in range(len(seq_lens))],
                dim=0)
            # Keep an intent when a strict majority of the sentence's tokens predict it
            # (seq_lens must be a tensor here for torch.div).
            res = (intent_index_sum > torch.div(seq_lens, 2, rounding_mode='floor').unsqueeze(1)).nonzero()
            if not return_list:
                return res
            return _group_indices_by_row(res, len(seq_lens))
    else:
        raise NotImplementedError("Prediction mode except ['slot','intent','token-level-intent'] is not supported.")

    # Some CRF.decode implementations already return Python lists; only tensors need
    # converting (calling .detach() on a list used to raise AttributeError).
    if return_list and isinstance(res, Tensor):
        res = res.detach().cpu().tolist()
    return res
def compute_loss(pred: OutputData,
                 target: InputData,
                 criterion_type="slot",
                 use_crf=False,
                 ignore_index=-100,
                 loss_fn=None,
                 use_multi=False,
                 CRF=None):
    """Compute the training loss for one prediction head.

    Args:
        pred (OutputData): model output; ``slot_ids`` is read for the slot criterion and
            ``intent_ids`` otherwise.
        target (InputData): gold data (``slot``, ``intent``, ``seq_lens``,
            ``get_slot_mask``).
        criterion_type (str, optional): one of ["slot", "intent", "token-level-intent"]. Any
            value other than "slot" and "token-level-intent" is handled as "intent".
            Defaults to "slot".
        use_crf (bool, optional): for the slot criterion, return the negated output of ``CRF``
            instead of using ``loss_fn``. Defaults to False.
        ignore_index (int, optional): forwarded to ``target.get_slot_mask`` to build the CRF
            mask. Defaults to -100.
        loss_fn (callable, optional): loss applied to (prediction, gold) pairs. Defaults to None.
        use_multi (bool, optional): for "token-level-intent", whether the gold intent has an
            extra trailing dimension (multi-intent) that must be kept when broadcasting it to
            every token. Defaults to False.
        CRF (CRF, optional): callable ``CRF(emissions, tags, mask)``. Defaults to None.

    Returns:
        Tensor: loss result.
    """
    if criterion_type == "slot":
        if use_crf:
            slot_mask = target.get_slot_mask(ignore_index).byte()
            log_likelihood = CRF(pred.slot_ids, target.slot, slot_mask)
            return -1 * log_likelihood
        packed_pred = utils.pack_sequence(pred.slot_ids, target.seq_lens)
        packed_gold = utils.pack_sequence(target.slot, target.seq_lens)
        return loss_fn(packed_pred, packed_gold)

    if criterion_type != "token-level-intent":
        # Sentence-level intent classification.
        return loss_fn(pred.intent_ids, target.intent)

    # TODO: two decode functions.
    # Copy each sentence's gold intent onto every token position so it can be compared
    # with the per-token intent logits.
    num_positions = pred.intent_ids.shape[1]
    repeat_dims = (1, num_positions, 1) if use_multi else (1, num_positions)
    gold_per_token = target.intent.unsqueeze(1).repeat(*repeat_dims)
    packed_pred = utils.pack_sequence(pred.intent_ids, target.seq_lens)
    packed_gold = utils.pack_sequence(gold_per_token, target.seq_lens)
    return loss_fn(packed_pred, packed_gold)