aiola committed
Commit 5fe2021
1 Parent(s): b12baee

Update README.md

Files changed (1)
  1. README.md +43 -100
README.md CHANGED
@@ -13,18 +13,19 @@ tags:
    - Named entity recognition
  ---
 
- # Whisper Ner
+ # Whisper-NER
 
- Whisper ner is an advanced model that allows joint speech transcription and entity recognition.
+ - Paper: [_WhisperNER: Unified Open Named Entity and Speech Recognition_](https://arxiv.org/abs/2409.08107).
+ - Code: https://github.com/aiola-lab/whisper-ner
+
+ We introduce WhisperNER, a novel model that allows joint speech transcription and entity recognition.
  WhisperNER supports open-type NER, enabling recognition of diverse and evolving entities at inference.
- We augment a large synthetic dataset with synthetic speech samples.
- This allows us to train WhisperNER on a large number of examples with diverse NER tags.
- During training, the model is prompted with NER labels and optimized to output the transcribed utterance along with the corresponding tagged entities.
 
  ---------
 
  ## Training Details
- `aiola/whisper-ner-v1` was trained on the Nuner dataset to perform audio translation with ner at the same time in English only.
+ `aiola/whisper-ner-v1` was trained on the NuNER dataset to perform joint audio transcription and NER tagging.
+ The model was trained and evaluated only on English data. Check out the [paper](https://arxiv.org/abs/2409.08107) for full details.
 
  ---------
 
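The prompting scheme the updated intro and Training Details lines refer to is simple in practice: the prompt is a lowercased, comma-separated list of entity types, and the model returns the transcription with the matching spans tagged inline. A minimal sketch of that idea follows; the entity types and the `<type>...</type>` markup are illustrative assumptions, not something this commit specifies.

```python
# Build a WhisperNER-style prompt: a lowercased, comma-separated list of entity
# types. Open-type NER means arbitrary labels can be supplied at inference time.
entity_types = ["Person", "Company", "Location", "Medical Condition"]
prompt = ", ".join(t.lower() for t in entity_types)
print(prompt)  # person, company, location, medical condition

# Illustrative (assumed) output shape: the transcription with tagged spans, e.g.
# "<person>john</person> works at <company>aiola</company> in <location>tel aviv</location>"
```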
@@ -33,107 +34,49 @@ To use `whisper-ner-v1` install [`whisper-ner`](https://github.com/aiola-lab/whi
 
  Inference can be done using the following code:
  ```python
- import logging
- import argparse
  import torch
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
- from experiments.utils import set_logger, get_device, remove_suppress_tokens
- from experiments.utils.utils import UNSUPPRESS_TOKEN
  import torchaudio
- import numpy as np
-
- set_logger()
-
-
- @torch.no_grad()
- def main(model_path, audio_file_path, prompt, max_new_tokens, language, device):
-     # load model and processor from pre-trained
-     processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
-     model = WhisperForConditionalGeneration.from_pretrained(model_path)
-     remove_suppress_tokens(model)
-     logging.info(f"removed suppress tokens: {UNSUPPRESS_TOKEN}")
-
-     model = model.to(device)
 
-     # load audio file: user is responsible for loading the audio files themselves
-     target_sample_rate = 16000
-     signal, sampling_rate = torchaudio.load(audio_file_path)
-     resampler = torchaudio.transforms.Resample(sampling_rate, target_sample_rate)
-     signal = resampler(signal)
-     # convert to mono or remove first dim if needed
-     if signal.ndim == 2:
-         signal = torch.mean(signal, dim=0)
-     # pre-process to get the input features
-     input_features = processor(
-         signal, sampling_rate=target_sample_rate, return_tensors="pt"
-     ).input_features
-     input_features = input_features.to(device)
-
-     prompt = prompt.lower() # lowercase the prompt, to align with training
-
-     prompt_ids = processor.get_prompt_ids(prompt, return_tensors="pt")
-     prompt_ids = prompt_ids.to(device)
-
-     # generate token ids by running model forward sequentially
-     logging.info(f"Inference with prompt: '{prompt}'.")
+ model_path = "aiola/whisper-ner-v1"
+ audio_file_path = "path/to/audio/file"
+ prompt = "person, company, location" # comma separated entity tags
+
+ # load model and processor from pre-trained
+ processor = WhisperProcessor.from_pretrained(model_path)
+ model = WhisperForConditionalGeneration.from_pretrained(model_path)
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = model.to(device)
+
+ # load audio file: user is responsible for loading the audio files themselves
+ target_sample_rate = 16000
+ signal, sampling_rate = torchaudio.load(audio_file_path)
+ resampler = torchaudio.transforms.Resample(sampling_rate, target_sample_rate)
+ signal = resampler(signal)
+ # convert to mono or remove first dim if needed
+ if signal.ndim == 2:
+     signal = torch.mean(signal, dim=0)
+ # pre-process to get the input features
+ input_features = processor(
+     signal, sampling_rate=target_sample_rate, return_tensors="pt"
+ ).input_features
+ input_features = input_features.to(device)
+
+ prompt_ids = processor.get_prompt_ids(prompt.lower(), return_tensors="pt")
+ prompt_ids = prompt_ids.to(device)
+
+ # generate token ids by running model forward sequentially
+ with torch.no_grad():
      predicted_ids = model.generate(
          input_features,
-         max_new_tokens=max_new_tokens,
-         language=language,
          prompt_ids=prompt_ids,
          generation_config=model.generation_config,
+         language="en",
      )
 
-     # post-process token ids to text
-     transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
-     print(transcription)
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser(
-         description="Transcribe audio using Whisper model."
-     )
-     parser.add_argument(
-         "--model-path",
-         type=str,
-         required=True,
-         default="aiola/whisper-ner-v1",
-         help="Path to the pre-trained model components.",
-     )
-     parser.add_argument(
-         "--audio-file-path",
-         type=str,
-         required=True,
-         help="Path to the audio file (wav) to transcribe.",
-     )
-     parser.add_argument(
-         "--prompt",
-         type=str,
-         default="father",
-         help="Prompt text to guide the transcription.",
-     )
-     parser.add_argument(
-         "--max-new-tokens",
-         type=int,
-         default=256,
-         help="Maximum number of new tokens to generate.",
-     )
-     parser.add_argument(
-         "--language",
-         type=str,
-         default="en",
-         help="Language code for the transcription.",
-     )
-
-     args = parser.parse_args()
-     device = get_device()
-     main(
-         args.model_path,
-         args.audio_file_path,
-         args.prompt,
-         args.max_new_tokens,
-         args.language,
-         device,
-     )
-
+ # post-process token ids to text, remove prompt
+ transcription = processor.batch_decode(
+     predicted_ids[:, prompt_ids.shape[0]:], skip_special_tokens=True
+ )[0]
+ print(transcription)
  ```
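The updated snippet ends by printing the raw tagged transcription. A small parser can turn that into structured entities; the sketch below assumes inline `<type>...</type>` markup, which this commit does not confirm, so the pattern should be adapted to the model's actual output format.

```python
import re

# Sketch: pull (entity_type, entity_text) pairs out of a tagged transcription.
# Assumes inline <type>...</type> markup, which is illustrative and not pinned
# down by this commit; adjust the regex to the model's real output format.
def extract_entities(tagged_text: str) -> list[tuple[str, str]]:
    pattern = re.compile(r"<(?P<tag>[^<>/]+)>(?P<text>.*?)</(?P=tag)>", re.DOTALL)
    return [
        (m.group("tag").strip(), m.group("text").strip())
        for m in pattern.finditer(tagged_text)
    ]

# Hypothetical tagged output, for illustration only.
example = "<person>john</person> works at <company>aiola</company>"
print(extract_entities(example))  # [('person', 'john'), ('company', 'aiola')]
```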