import re
from dataclasses import dataclass
from typing import Any, List, Dict, Union
import torch
from zhconv import convert


# Remove punctuation marks (both half-width and full-width)
def remove_punctuation(text: Union[str, List[str]]):
    punctuation = '!,.;:?、!,。;:?'
    if isinstance(text, str):
        text = re.sub(r'[{}]+'.format(punctuation), '', text).strip()
        return text
    elif isinstance(text, list):
        result_text = []
        for t in text:
            t = re.sub(r'[{}]+'.format(punctuation), '', t).strip()
            result_text.append(t)
        return result_text
    else:
        raise TypeError(f'Unsupported type: {type(text)}')
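
# Quick illustration (the sample strings below are made up for demonstration):
#   remove_punctuation('你好,世界!')                 -> '你好世界'
#   remove_punctuation(['今天天气,真好。', 'Hi!'])   -> ['今天天气真好', 'Hi']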


# Convert Traditional Chinese to Simplified Chinese
def to_simple(text: Union[str, List[str]]):
    if isinstance(text, str):
        text = convert(text, 'zh-cn')
        return text
    elif isinstance(text, list):
        result_text = []
        for t in text:
            t = convert(t, 'zh-cn')
            result_text.append(t)
        return result_text
    else:
        raise TypeError(f'Unsupported type: {type(text)}')
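
# Quick illustration (the sample string is made up for demonstration):
#   to_simple('漢語轉換') -> '汉语转换'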


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Split inputs and labels since they have different lengths and need
        # different padding methods.
        # First treat the audio inputs by simply returning torch tensors.
        input_features = [{"input_features": feature["input_features"][0]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        # Get the tokenized label sequences.
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # Pad the labels to max length.
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        # Replace padding with -100 so those positions are ignored by the loss.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        # If the BOS token was appended in a previous tokenization step,
        # cut it here, since it is appended again later anyway.
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch
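
# A minimal sketch of how this collator is typically wired up, assuming the
# Hugging Face transformers library; the "openai/whisper-small" checkpoint
# name is illustrative, not something this file prescribes:
#
#     from transformers import WhisperProcessor
#
#     processor = WhisperProcessor.from_pretrained("openai/whisper-small")
#     data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
#
# Passing data_collator=data_collator to Seq2SeqTrainer then yields batches
# with padded input_features and -100-masked labels.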