---
license: mit
datasets:
- numind/NuNER
language:
- en
pipeline_tag: zero-shot-classification
tags:
- asr
- Automatic Speech Recognition
- Whisper
- Ner
- Named entity recognition
---
# Whisper-NER
- Paper: [_WhisperNER: Unified Open Named Entity and Speech Recognition_](https://arxiv.org/abs/2409.08107).
- Code: https://github.com/aiola-lab/whisper-ner
We introduce WhisperNER, a novel model that performs joint speech transcription and named entity recognition (NER).
WhisperNER supports open-type NER: the entity types to recognize are supplied as a prompt at inference time, so the model can recognize diverse and evolving entity types.
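At inference, the entity types are passed as a plain comma-separated prompt, which the Usage example lowercases before tokenizing. The helper below is hypothetical and not part of the `whisper-ner` repo. It is a minimal sketch, assuming you want to normalize a list of type names (trim whitespace, lowercase, drop duplicates) into that prompt format.

```python
def build_entity_prompt(entity_types):
    """Hypothetical helper (not part of whisper-ner): join entity type names
    into the comma-separated, lower-case prompt used in the Usage example."""
    seen = set()
    cleaned = []
    for name in entity_types:
        tag = name.strip().lower()
        # skip empty names and duplicates while preserving the original order
        if tag and tag not in seen:
            seen.add(tag)
            cleaned.append(tag)
    return ", ".join(cleaned)


prompt = build_entity_prompt(["Person", "Company", " Location ", "person"])
print(prompt)  # person, company, location
```

Because the types are free text, you can change this list per request without retraining.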
---------
## Training Details
`aiola/whisper-ner-v1` was trained on the NuNER dataset to perform joint audio transcription and NER tagging.
The model was trained and evaluated only on English data. Check out the [paper](https://arxiv.org/abs/2409.08107) for full details.
---------
## Usage
To use `whisper-ner-v1`, install the [`whisper-ner`](https://github.com/aiola-lab/whisper-ner) repo following its README instructions.
Inference can then be run with the following code:
```python
import torch
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration

model_path = "aiola/whisper-ner-v1"
audio_file_path = "path/to/audio/file"
prompt = "person, company, location"  # comma-separated entity tags

# load model and processor from pre-trained
processor = WhisperProcessor.from_pretrained(model_path)
model = WhisperForConditionalGeneration.from_pretrained(model_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# load audio file: user is responsible for loading the audio files themselves
target_sample_rate = 16000  # Whisper expects 16 kHz input
signal, sampling_rate = torchaudio.load(audio_file_path)
resampler = torchaudio.transforms.Resample(sampling_rate, target_sample_rate)
signal = resampler(signal)
# torchaudio returns (channels, samples): average the channels to get a 1-D mono signal
if signal.ndim == 2:
    signal = torch.mean(signal, dim=0)

# pre-process to get the input features (log-Mel spectrogram)
input_features = processor(
    signal, sampling_rate=target_sample_rate, return_tensors="pt"
).input_features
input_features = input_features.to(device)

# the entity types are passed to the decoder as a lower-cased text prompt
prompt_ids = processor.get_prompt_ids(prompt.lower(), return_tensors="pt")
prompt_ids = prompt_ids.to(device)

# generate token ids autoregressively
with torch.no_grad():
    predicted_ids = model.generate(
        input_features,
        prompt_ids=prompt_ids,
        generation_config=model.generation_config,
        language="en",
    )

# post-process token ids to text, remove prompt
transcription = processor.batch_decode(
    predicted_ids[:, prompt_ids.shape[0]:], skip_special_tokens=True
)[0]
print(transcription)
```
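The decoded `transcription` is expected to contain the transcript with the recognized entities tagged inline. The sketch below assumes an XML-style `<type>span</type>` tagging scheme. That format is an assumption here, so check the actual model output (and the `whisper-ner` repo) before relying on it. `extract_entities` is a hypothetical helper that splits such a string into plain text and `(type, span)` pairs.

```python
import re

# ASSUMPTION: entities are marked as <type>span</type>; verify this against
# real whisper-ner-v1 output before using the helper.
ENTITY_PATTERN = re.compile(r"<([^<>/]+)>(.*?)</\1>")


def extract_entities(tagged_text):
    """Hypothetical helper: return (plain_text, [(entity_type, span), ...])."""
    entities = [
        (m.group(1).strip(), m.group(2).strip())
        for m in ENTITY_PATTERN.finditer(tagged_text)
    ]
    # replace each tagged entity with its bare span to recover the plain transcript
    plain_text = ENTITY_PATTERN.sub(lambda m: m.group(2), tagged_text)
    return plain_text, entities


text, ents = extract_entities(
    "<person>Alice</person> joined <company>Acme</company> in <location>Paris</location>."
)
print(text)  # Alice joined Acme in Paris.
print(ents)  # [('person', 'Alice'), ('company', 'Acme'), ('location', 'Paris')]
```

If the real output uses a different delimiter scheme, only `ENTITY_PATTERN` needs to change.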