Harveenchadha commited on
Commit
8b8bd52
1 Parent(s): 90f5dcf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -12
app.py CHANGED
@@ -22,22 +22,23 @@ def read_file(wav):
22
 
23
 
24
  def parse_transcription_with_lm(wav_file):
25
- speech = convert_file(wav_file)
26
- inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
27
 
28
  with torch.no_grad():
29
- logits = model(**inputs).logits
30
- int_result = processor.batch_decode(logits.cpu().numpy())
31
 
32
  transcription = int_result.text
33
  return transcription
34
 
35
 
36
- def convert_file(wav_file):
37
  filename = wav_file.split('.')[0]
38
  convert(wav_file, filename + "16k.wav")
39
  speech, _ = sf.read(filename + "16k.wav")
40
- return speech
 
 
41
 
42
  def parse(wav_file, applyLM):
43
  if applyLM:
@@ -46,12 +47,7 @@ def parse(wav_file, applyLM):
46
  return parse_transcription(wav_file)
47
 
48
  def parse_transcription(wav_file):
49
- speech = convert_file(wav_file)
50
-
51
-
52
- #speech = read_file(wav_file)
53
- input_values = processor(speech, sampling_rate=16_000, return_tensors="pt").input_values
54
-
55
  logits = model(input_values).logits
56
  predicted_ids = torch.argmax(logits, dim=-1)
57
 
 
22
 
23
 
24
  def parse_transcription_with_lm(wav_file):
25
+ input_values = read_file_and_process(wav_file)
 
26
 
27
  with torch.no_grad():
28
+ logits = model(**input_values).logits
29
+ int_result = processor.decode(logits.cpu().numpy())
30
 
31
  transcription = int_result.text
32
  return transcription
33
 
34
 
35
+ def read_file_and_process(wav_file):
36
  filename = wav_file.split('.')[0]
37
  convert(wav_file, filename + "16k.wav")
38
  speech, _ = sf.read(filename + "16k.wav")
39
+ inputs = processor(speech, sampling_rate=16_000, return_tensors="pt", padding=True)
40
+
41
+ return inputs
42
 
43
  def parse(wav_file, applyLM):
44
  if applyLM:
 
47
  return parse_transcription(wav_file)
48
 
49
  def parse_transcription(wav_file):
50
+ input_values = read_file_and_process(wav_file)
 
 
 
 
 
51
  logits = model(input_values).logits
52
  predicted_ids = torch.argmax(logits, dim=-1)
53