---
license: mit
datasets:
  - numind/NuNER
language:
  - en
pipeline_tag: zero-shot-classification
tags:
  - asr
  - Automatic Speech Recognition
  - Whisper
  - Ner
  - Named entity recognition
---

Whisper-NER

We introduce WhisperNER, a novel model for joint speech transcription and named entity recognition. WhisperNER supports open-type NER: the entity types to recognize are supplied at inference time, so the model can handle diverse and evolving entity types.
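Because the entity types are open, they are given to the model as plain text; the usage code passes them as one comma-separated, lowercased prompt string. The helper below, `build_entity_prompt` (a hypothetical name, not part of whisper-ner), is a minimal sketch for building such a string from a list of type names. Its normalization rules (collapsing whitespace, dropping duplicates) are assumptions, not documented requirements of the model.

```python
from typing import Iterable, List


def build_entity_prompt(entity_types: Iterable[str]) -> str:
    """Join entity type names into a comma-separated, lowercased prompt.

    Hypothetical helper: the model card only shows a literal prompt such as
    "person, company, location"; the normalization here is an assumption.
    """
    cleaned_types: List[str] = []
    for name in entity_types:
        # Collapse internal whitespace and lowercase, mirroring prompt.lower()
        cleaned = " ".join(name.split()).lower()
        # Skip empty names and duplicates while keeping the original order
        if cleaned and cleaned not in cleaned_types:
            cleaned_types.append(cleaned)
    return ", ".join(cleaned_types)


prompt = build_entity_prompt(["Person", " Company ", "location", "person"])
print(prompt)  # person, company, location
```

Multi-word types such as `"medical condition"` are kept intact; only the separator between types is a comma.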


Training Details

aiola/whisper-ner-v1 was trained on the NuNER dataset to perform joint audio transcription and NER tagging. The model was trained and evaluated only on English data. Check out the paper for full details.


Usage

To use whisper-ner-v1, install the whisper-ner repository by following the instructions in its README.

Inference can be done using the following code:

```python
import torch
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration

model_path = "aiola/whisper-ner-v1"
audio_file_path = "path/to/audio/file"
prompt = "person, company, location"  # comma-separated entity tags

# load model and processor from pre-trained
processor = WhisperProcessor.from_pretrained(model_path)
model = WhisperForConditionalGeneration.from_pretrained(model_path)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# load audio file: users are responsible for loading audio files themselves
target_sample_rate = 16000
signal, sampling_rate = torchaudio.load(audio_file_path)
resampler = torchaudio.transforms.Resample(sampling_rate, target_sample_rate)
signal = resampler(signal)
# convert to mono (average over channels), which also removes the channel dim
if signal.ndim == 2:
    signal = torch.mean(signal, dim=0)
# pre-process to get the input features
input_features = processor(
    signal, sampling_rate=target_sample_rate, return_tensors="pt"
).input_features
input_features = input_features.to(device)

prompt_ids = processor.get_prompt_ids(prompt.lower(), return_tensors="pt")
prompt_ids = prompt_ids.to(device)

# generate token ids by running model forward sequentially
with torch.no_grad():
    predicted_ids = model.generate(
        input_features,
        prompt_ids=prompt_ids,
        generation_config=model.generation_config,
        language="en",
    )

# post-process token ids to text, remove prompt
transcription = processor.batch_decode(
    predicted_ids[:, prompt_ids.shape[0]:], skip_special_tokens=True
)[0]
print(transcription)
```
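This card does not document the exact format of the decoded `transcription`. As a minimal post-processing sketch, assuming (hypothetically) that entities appear as inline tags named after the prompted types, such as `<person>Alice</person>` or `<person>Alice<person>`, the function below splits the output into a plain transcript and a list of `(entity_type, span)` pairs. `extract_entities` and `TAG_PATTERN` are illustrative names, not part of whisper-ner; check the actual model output before relying on this format.

```python
import re
from typing import List, Tuple

# ASSUMED output format (hypothetical): an entity span wrapped in a tag named
# after its type, closed either as </type> or by repeating <type>.
TAG_PATTERN = re.compile(r"<([^<>/]+)>(.*?)</?\1>")


def extract_entities(transcription: str) -> Tuple[str, List[Tuple[str, str]]]:
    """Return (plain_text, [(entity_type, span), ...]) for a tagged transcription."""
    entities = [
        (match.group(1).strip(), match.group(2).strip())
        for match in TAG_PATTERN.finditer(transcription)
    ]
    # Replace each tagged span with its inner text to recover a plain transcript
    plain_text = TAG_PATTERN.sub(lambda match: match.group(2), transcription)
    return plain_text, entities


text = "<person>Alice</person> joined <company>Acme Corp<company> in <location>Paris</location>."
plain_text, entities = extract_entities(text)
print(plain_text)  # Alice joined Acme Corp in Paris.
print(entities)    # [('person', 'Alice'), ('company', 'Acme Corp'), ('location', 'Paris')]
```

The backreference `\1` ensures a span is only closed by a tag of the same type, so adjacent entities of different types are not merged.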