AshwinSankar committed on
Commit
49a5c13
·
verified ·
1 Parent(s): 5ea365d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -10
app.py CHANGED
@@ -67,28 +67,74 @@ DEFAULT_TARGET_LANGUAGE = "Bengali"
67
 
68
  @spaces.GPU
69
  def run_asr_ctc(input_audio: str, target_language: str) -> str:
70
- # preprocess_audio(input_audio)
71
- # input_audio, orig_freq = torchaudio.load(input_audio)
72
- # input_audio = torchaudio.functional.resample(input_audio, orig_freq=orig_freq, new_freq=16000)
73
  lang_id = LANGUAGE_NAME_TO_CODE[target_language]
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  model.cur_decoder = "ctc"
76
- ctc_text = model.transcribe([input_audio], batch_size=1, logprobs=False, language_id=lang_id)[0]
77
-
78
  return ctc_text[0]
79
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  @spaces.GPU
81
  def run_asr_rnnt(input_audio: str, target_language: str) -> str:
82
- # preprocess_audio(input_audio)
83
- # input_audio, orig_freq = torchaudio.load(input_audio)
84
- # input_audio = torchaudio.functional.resample(input_audio, orig_freq=orig_freq, new_freq=16000)
85
  lang_id = LANGUAGE_NAME_TO_CODE[target_language]
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  model.cur_decoder = "rnnt"
88
- ctc_text = model.transcribe([input_audio], batch_size=1,logprobs=False, language_id=lang_id)[0]
89
-
90
  return ctc_text[0]
91
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
 
94
  with gr.Blocks() as demo_asr_ctc:
 
67
 
68
@spaces.GPU
def run_asr_ctc(input_audio: str, target_language: str) -> str:
    """Transcribe an audio file using the model's CTC decoder.

    Args:
        input_audio: Path to the input audio file on disk.
        target_language: Human-readable language name; must be a key of
            ``LANGUAGE_NAME_TO_CODE``.

    Returns:
        The transcribed text for the given audio.

    Raises:
        KeyError: If ``target_language`` is not in ``LANGUAGE_NAME_TO_CODE``.
    """
    lang_id = LANGUAGE_NAME_TO_CODE[target_language]

    # Load audio as a [channels x T] float tensor plus its sample rate.
    audio_tensor, orig_freq = torchaudio.load(input_audio)

    # Downmix multi-channel audio to mono.
    if audio_tensor.shape[0] > 1:
        audio_tensor = torch.mean(audio_tensor, dim=0, keepdim=True)

    # Ensure shape [1 x T] (add a leading dimension if the loader returned 1-D).
    if len(audio_tensor.shape) == 1:
        audio_tensor = audio_tensor.unsqueeze(0)

    # Resample to the 16 kHz rate the model expects.
    audio_tensor = torchaudio.functional.resample(
        audio_tensor, orig_freq=orig_freq, new_freq=16000
    )

    model.cur_decoder = "ctc"
    # NeMo's transcribe() expects a 1-D waveform per list item, so drop the
    # channel dimension before converting to numpy (the original passed a
    # 2-D [1, T] array).
    ctc_text = model.transcribe(
        [audio_tensor.squeeze(0).numpy()],
        batch_size=1,
        logprobs=False,
        language_id=lang_id,
    )[0]
    return ctc_text[0]
102
+
103
@spaces.GPU
def run_asr_rnnt(input_audio: str, target_language: str) -> str:
    """Transcribe an audio file using the model's RNN-T decoder.

    Args:
        input_audio: Path to the input audio file on disk.
        target_language: Human-readable language name; must be a key of
            ``LANGUAGE_NAME_TO_CODE``.

    Returns:
        The transcribed text for the given audio.

    Raises:
        KeyError: If ``target_language`` is not in ``LANGUAGE_NAME_TO_CODE``.
    """
    lang_id = LANGUAGE_NAME_TO_CODE[target_language]

    # Load audio as a [channels x T] float tensor plus its sample rate.
    audio_tensor, orig_freq = torchaudio.load(input_audio)

    # Downmix multi-channel audio to mono.
    if audio_tensor.shape[0] > 1:
        audio_tensor = torch.mean(audio_tensor, dim=0, keepdim=True)

    # Ensure shape [1 x T] (add a leading dimension if the loader returned 1-D).
    if len(audio_tensor.shape) == 1:
        audio_tensor = audio_tensor.unsqueeze(0)

    # Resample to the 16 kHz rate the model expects.
    audio_tensor = torchaudio.functional.resample(
        audio_tensor, orig_freq=orig_freq, new_freq=16000
    )

    model.cur_decoder = "rnnt"
    # NeMo's transcribe() expects a 1-D waveform per list item, so drop the
    # channel dimension before converting to numpy (the original passed a
    # 2-D [1, T] array). Renamed from the copy-pasted `ctc_text` — this is
    # the RNN-T decoding path.
    rnnt_text = model.transcribe(
        [audio_tensor.squeeze(0).numpy()],
        batch_size=1,
        logprobs=False,
        language_id=lang_id,
    )[0]
    return rnnt_text[0]
137
+
138
 
139
 
140
  with gr.Blocks() as demo_asr_ctc: