jiuuee committed on
Commit
3bdfcbc
1 Parent(s): e6d8983

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -57
app.py CHANGED
@@ -18,37 +18,12 @@ import uuid
18
 
19
  import torch
20
 
21
- from nemo.collections.asr.models import ASRModel
22
- from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED
23
- from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_multitaskAED
24
 
25
  SAMPLE_RATE = 16000 # Hz
26
  MAX_AUDIO_SECS = 30 # wont try to transcribe if longer than this
27
-
28
- model = ASRModel.from_pretrained("nvidia/canary-1b")
29
- model.eval()
30
-
31
- # make sure beam size always 1 for consistency
32
- model.change_decoding_strategy(None)
33
- decoding_cfg = model.cfg.decoding
34
- decoding_cfg.beam.beam_size = 1
35
- model.change_decoding_strategy(decoding_cfg)
36
-
37
- # setup for buffered inference
38
- model.cfg.preprocessor.dither = 0.0
39
- model.cfg.preprocessor.pad_to = 0
40
-
41
- feature_stride = model.cfg.preprocessor['window_stride']
42
- model_stride_in_secs = feature_stride * 8 # 8 = model stride, which is 8 for FastConformer
43
-
44
- frame_asr = FrameBatchMultiTaskAED(
45
- asr_model=model,
46
- frame_len=40.0,
47
- total_buffer=40.0,
48
- batch_size=16,
49
- )
50
-
51
- amp_dtype = torch.float16
52
 
53
  def convert_audio(audio_filepath, tmpdir, utt_id):
54
  """
@@ -78,50 +53,36 @@ def convert_audio(audio_filepath, tmpdir, utt_id):
78
  return out_filename, duration
79
 
80
 
81
- def transcribe(audio_filepath):
 
 
 
82
  if audio_filepath is None:
83
  raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")
84
 
85
  utt_id = uuid.uuid4()
86
 
87
  with tempfile.TemporaryDirectory() as tmpdir:
88
- converted_audio_filepath, duration = convert_audio(audio_filepath, tmpdir, str(utt_id))
89
-
90
  # Make manifest file and save
91
  manifest_data = {
92
- "audio_filepath": converted_audio_filepath,
93
- "source_lang": "en",
94
- "target_lang": "en",
95
- "taskname": "asr",
96
- "pnc": "no",
97
- "answer": "predict",
98
- "duration": 10,
99
  }
100
 
101
  manifest_filepath = os.path.join(tmpdir, f'{utt_id}.json')
102
 
103
  with open(manifest_filepath, 'w') as fout:
104
- json.dump(manifest_data, fout) # Fix: using json.dump to write manifest data
105
-
106
- # Call transcribe, passing in manifest filepath
107
- if duration < 40:
108
- output_text = model.transcribe(manifest_filepath)[0]
109
- else: # Do buffered inference
110
- with torch.cuda.amp.autocast(dtype=amp_dtype): # TODO: make it work if no cuda
111
- with torch.no_grad():
112
- hyps = get_buffered_pred_feat_multitaskAED(
113
- frame_asr,
114
- model.cfg.preprocessor,
115
- model_stride_in_secs,
116
- model.device,
117
- manifest=manifest_filepath,
118
- filepaths=None,
119
- )
120
-
121
- output_text = hyps[0].text
122
 
123
- return output_text
 
 
124
 
 
125
 
126
 
127
 
 
18
 
19
  import torch
20
 
 
 
 
21
 
22
  SAMPLE_RATE = 16000 # Hz
23
  MAX_AUDIO_SECS = 30 # wont try to transcribe if longer than this
24
+ src_lang = "en"
25
+ tgt_lang = "en"
26
+ pnc="no"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  def convert_audio(audio_filepath, tmpdir, utt_id):
29
  """
 
53
  return out_filename, duration
54
 
55
 
56
+ # Load the ASR pipeline
57
+ asr_pipeline = pipeline("automatic-speech-recognition", model="nvidia/canary-1b")
58
+
59
+ def transcribe(audio_filepath, src_lang, tgt_lang, pnc):
60
  if audio_filepath is None:
61
  raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")
62
 
63
  utt_id = uuid.uuid4()
64
 
65
  with tempfile.TemporaryDirectory() as tmpdir:
 
 
66
  # Make manifest file and save
67
  manifest_data = {
68
+ "audio_filepath": audio_filepath,
69
+ "source_lang": src_lang,
70
+ "target_lang": tgt_lang,
71
+ "taskname": "asr", # Setting taskname to "asr"
72
+ "pnc": pnc,
73
+ "answer": "predict"
 
74
  }
75
 
76
  manifest_filepath = os.path.join(tmpdir, f'{utt_id}.json')
77
 
78
  with open(manifest_filepath, 'w') as fout:
79
+ json.dump(manifest_data, fout)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
+ # Transcribe audio using ASR pipeline
82
+ transcribed_text = asr_pipeline(audio_filepath)
83
+ output_text = transcribed_text[0]['transcription']
84
 
85
+ return output_text
86
 
87
 
88