sepal committed
Commit 427085e
1 parent: 60bd8ec

Add support for the OpenAI Whisper API.

Files changed (1): app.py +21 -6
app.py CHANGED
@@ -13,11 +13,9 @@ import tempfile
 load_dotenv()
 
 hg_token = os.getenv("HG_ACCESS_TOKEN")
+open_api_key = os.getenv("OPENAI_API_KEY")
 
-if hg_token != None:
-    pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hg_token)
-    whisper_ml = whisper.load_model("base")
-else:
+if hg_token == None:
     print('''No hugging face access token set.
 You need to set it via an .env or environment variable HG_ACCESS_TOKEN''')
     exit(1)
@@ -27,6 +25,7 @@ def diarization(audio) -> np.array:
     """
     Receives a pydub AudioSegment and returns a numpy array with all segments.
     """
+    pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hg_token)
     audio.export("/tmp/dz.wav", format="wav")
     diarization = pipeline("/tmp/dz.wav")
     return pd.DataFrame(list(diarization.itertracks(yield_label=True)),columns=["Segment","Trackname", "Speaker"])
@@ -50,8 +49,24 @@ def prep_audio(audio_segment):
 
 def transcribe_row(row, audio):
     segment = audio[row.start_ms:row.end_ms]
-    data = prep_audio(segment)
-    return whisper_ml.transcribe(data)['text']
+    if open_api_key == None:
+        whisper_ml = whisper.load_model("base")
+        data = prep_audio(segment)
+        return whisper_ml.transcribe(data)['text']
+    else:
+        print("Using openai API")
+        # the openai whisper API only accepts audio files with a length of
+        # at least 0.1 seconds.
+        if row['end_ms'] - row['start_ms'] < 100:
+            return ""
+        import openai
+        import tempfile
+        temp_file = f"/tmp/{row['Trackname']}.mp3"
+        segment.export(temp_file, format="mp3")
+        print(temp_file)
+        audio_file = open(temp_file, "rb")
+        return openai.Audio.translate("whisper-1", audio_file)['text']
+
 
 
  def combine_transcription(segments):
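
For reference, both tokens the script reads can be supplied through a local .env file, which the load_dotenv() call at the top of app.py picks up. A minimal sketch, with placeholder values:

HG_ACCESS_TOKEN=hf_your_hugging_face_token
OPENAI_API_KEY=sk-your_openai_key

If OPENAI_API_KEY is unset, transcribe_row falls back to the local "base" whisper model.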
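
The new API branch of transcribe_row can also be exercised on its own. The sketch below is a minimal standalone version, assuming the pre-1.0 openai Python SDK (the one that exposes openai.Audio.translate) and a pydub AudioSegment as input; translate_segment is a hypothetical helper, not a name from app.py:

import os

import openai
from pydub import AudioSegment

openai.api_key = os.getenv("OPENAI_API_KEY")

def translate_segment(segment: AudioSegment, name: str) -> str:
    # The hosted whisper-1 endpoint rejects clips shorter than 0.1 s,
    # the same guard the commit adds; pydub lengths are in milliseconds.
    if len(segment) < 100:
        return ""
    # Hypothetical temp path mirroring the commit's /tmp/<Trackname>.mp3 scheme.
    temp_file = f"/tmp/{name}.mp3"
    segment.export(temp_file, format="mp3")
    # A with-block closes the handle that the bare open() in app.py leaves open.
    with open(temp_file, "rb") as audio_file:
        return openai.Audio.translate("whisper-1", audio_file)["text"]

One design note: openai.Audio.translate always returns English text; the same SDK also offers openai.Audio.transcribe("whisper-1", file), which keeps the source language, so transcribe may be the better fit if the goal is transcription rather than translation.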