---
license: mit
datasets:
- numind/NuNER
language:
- en
pipeline_tag: automatic-speech-recognition
tags:
- asr
- Automatic Speech Recognition
- Whisper
- Named entity recognition
---
|
|
|
# Whisper-NER |
|
|
|
- Demo: https://huggingface.co/spaces/aiola/whisper-ner-v1 |
|
- Paper: [_WhisperNER: Unified Open Named Entity and Speech Recognition_](https://arxiv.org/abs/2409.08107)
|
- Code: https://github.com/aiola-lab/whisper-ner |
|
|
|
We introduce WhisperNER, a novel model that enables joint speech transcription and named entity recognition.

WhisperNER supports open-type NER: the entity types to recognize are supplied as a prompt at inference time rather than fixed at training time, so the model can handle diverse and evolving entities. WhisperNER is designed as a strong base model for the downstream task of ASR with NER, and can be fine-tuned on specific datasets for improved performance.
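For illustration, an entity type here is just a free-form label string, and any comma-separated label set can be passed as the prompt, as in the usage code below (the domain-specific labels in the second line are hypothetical examples, not a fixed tag set):

```python
# entity types are free-form strings chosen at inference time
prompt = "person, company, location"              # generic NER labels
prompt = "medication, dosage, medical condition"  # hypothetical domain-specific labels
```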
|
|
|
**NOTE:** This model also supports entity masking directly in the output transcript, which is especially relevant for PII use cases. However, the model was not trained on PII-specific datasets, so it performs general, open-type entity masking, and **it should be further fine-tuned before being used for PII tasks**.
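Masking is toggled by prefixing the entity-tag prompt with the mask token, exactly as in the usage code below:

```python
mask_token = "<|mask|>"
prompt = f"{mask_token}person, company, location"  # masked entity values are replaced in the transcript
```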
|
|
|
|
|
--------- |
|
|
|
## Training Details |
|
`aiola/whisper-ner-tag-and-mask-v1` was fine-tuned from `aiola/whisper-ner-v1` on the NuNER dataset to perform joint audio transcription with NER tagging or NER masking. The model was trained and evaluated only on English data. Check out the [paper](https://arxiv.org/abs/2409.08107) for full details.
|
|
|
--------- |
|
|
|
## Usage |
|
|
|
Inference can be done using the following code; for more inference code and details, check out the [whisper-ner repo](https://github.com/aiola-lab/whisper-ner):
|
|
|
```python |
|
import torch
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration

model_path = "aiola/whisper-ner-tag-and-mask-v1"
audio_file_path = "path/to/audio/file"
prompt = "person, company, location"  # comma-separated entity tags
apply_entity_mask = False  # change to True for entity masking
mask_token = "<|mask|>"

if apply_entity_mask:
    prompt = f"{mask_token}{prompt}"

# load model and processor from pre-trained
processor = WhisperProcessor.from_pretrained(model_path)
model = WhisperForConditionalGeneration.from_pretrained(model_path)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# load audio file: the user is responsible for loading the audio files themselves
target_sample_rate = 16000
signal, sampling_rate = torchaudio.load(audio_file_path)
resampler = torchaudio.transforms.Resample(sampling_rate, target_sample_rate)
signal = resampler(signal)

# convert to mono or remove first dim if needed
if signal.ndim == 2:
    signal = torch.mean(signal, dim=0)

# pre-process to get the input features
input_features = processor(
    signal, sampling_rate=target_sample_rate, return_tensors="pt"
).input_features
input_features = input_features.to(device)

prompt_ids = processor.get_prompt_ids(prompt.lower(), return_tensors="pt")
prompt_ids = prompt_ids.to(device)

# generate token ids by running model forward sequentially
with torch.no_grad():
    predicted_ids = model.generate(
        input_features,
        prompt_ids=prompt_ids,
        generation_config=model.generation_config,
        language="en",
    )

# post-process token ids to text, remove prompt
transcription = processor.batch_decode(
    predicted_ids, skip_special_tokens=True
)[0]
print(transcription)
|
``` |
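If you do not have a local audio file at hand, one way to smoke-test the snippet above is with a sample from a public ASR dataset. This is a sketch, assuming the `datasets` library is installed; it reuses `processor`, `model`, `device`, and `prompt_ids` from the code above, and the dummy LibriSpeech sample is already at 16 kHz, so no resampling is needed:

```python
from datasets import load_dataset

# illustrative smoke test on a public sample (assumes the snippet above has run)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[0]["audio"]  # dict with "array" and "sampling_rate" (16 kHz)

input_features = processor(
    sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
).input_features.to(device)

with torch.no_grad():
    predicted_ids = model.generate(
        input_features,
        prompt_ids=prompt_ids,
        generation_config=model.generation_config,
        language="en",
    )

print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
```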