import datetime
import math
import os

import numpy as np
import torch
import torchaudio
from funasr import AutoModel
from pyannote.audio import Audio, Pipeline
from pyannote.core import Segment
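
# Load SenseVoiceSmall for speech recognition (it also emits emotion/event tags)
# and a pyannote speaker-diarization pipeline; both run on GPU when available.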
model = AutoModel(
    model="FunAudioLLM/SenseVoiceSmall",
    hub="hf",
    device="cuda" if torch.cuda.is_available() else "cpu",
)

pyannote_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1", use_auth_token=os.getenv("HF_TOKEN")
)
if torch.cuda.is_available():
    pyannote_pipeline.to(torch.device("cuda"))
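
# SenseVoice emits special tokens such as <|HAPPY|> or <|BGM|>; these tables map
# them to emoji (or drop them) when cleaning up the transcript text.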
emo_dict = {
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
}

event_dict = {
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|Cry|>": "😭",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "🤧",
}

emoji_dict = {
    "<|nospeech|><|Event_UNK|>": "❓",
    "<|zh|>": "",
    "<|en|>": "",
    "<|yue|>": "",
    "<|ja|>": "",
    "<|ko|>": "",
    "<|nospeech|>": "",
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
    "<|Cry|>": "😭",
    "<|EMO_UNKNOWN|>": "",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "😷",
    "<|Sing|>": "",
    "<|Speech_Noise|>": "",
    "<|withitn|>": "",
    "<|woitn|>": "",
    "<|GBG|>": "",
    "<|Event_UNK|>": "",
}

lang_dict = {
    "<|zh|>": "<|lang|>",
    "<|en|>": "<|lang|>",
    "<|yue|>": "<|lang|>",
    "<|ja|>": "<|lang|>",
    "<|ko|>": "<|lang|>",
    "<|nospeech|>": "<|lang|>",
}

emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
event_set = {"🎼", "👏", "😀", "😭", "🤧", "😷"}
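

# Turn raw SenseVoice output into readable text: strip the special language/ITN
# tokens, prefix detected audio events (music, applause, laughter, ...), and
# attach the dominant emotion emoji to each <|lang|>-delimited segment.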
def clean_and_emoji_annotate_speech(text):
    def get_emoji(s, emoji_set):
        return next((char for char in s if char in emoji_set), None)

    def format_text_with_emojis(s):
        # Count each special token, then strip them all from the text.
        sptk_dict = {sptk: s.count(sptk) for sptk in emoji_dict}
        for sptk in emoji_dict:
            s = s.replace(sptk, "")

        # Pick the most frequent emotion tag, defaulting to NEUTRAL.
        emo = "<|NEUTRAL|>"
        for e in emo_dict:
            if sptk_dict.get(e, 0) > sptk_dict.get(emo, 0):
                emo = e

        # Prefix all detected event emojis and append the emotion emoji.
        s = (
            "".join(event_dict[e] for e in event_dict if sptk_dict.get(e, 0) > 0)
            + s
            + emo_dict[emo]
        )

        # Remove stray spaces around emojis.
        for emoji in emo_set.union(event_set):
            s = s.replace(f" {emoji}", emoji).replace(f"{emoji} ", emoji)

        return s.strip()

    text = text.replace("<|nospeech|><|Event_UNK|>", "❓")
    for lang, replacement in lang_dict.items():
        text = text.replace(lang, replacement)

    segments = [
        format_text_with_emojis(segment.strip()) for segment in text.split("<|lang|>")
    ]

    formatted_segments = []
    prev_event = prev_emotion = None

    for segment in segments:
        if not segment:
            continue

        current_event = get_emoji(segment, event_set)
        current_emotion = get_emoji(segment, emo_set)

        # Strip a leading event emoji from the segment.
        if current_event is not None:
            segment = segment[1:] if segment.startswith(current_event) else segment

        # When the emotion changes, move its emoji to the end of the segment.
        if current_emotion is not None and current_emotion != prev_emotion:
            segment = segment.replace(current_emotion, "") + current_emotion

        formatted_segments.append(segment.strip())
        prev_event, prev_emotion = current_event, current_emotion

    result = " ".join(formatted_segments).replace("The.", "").strip()
    return result
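

# Helpers for converting between seconds and the time strings used in the
# diarization log ("HH:MM:SS.sss", "MM:SS.sss", or "SS.sss" with a trailing "s").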
def time_to_seconds(time_str):
    h, m, s = time_str.split(":")
    return round(int(h) * 3600 + int(m) * 60 + float(s), 9)


def parse_time(time_str):
    # Drop the trailing "s" produced by format_time's seconds-only output.
    time_str = time_str.rstrip("s")

    parts = time_str.split(":")
    if len(parts) == 3:
        h, m, s = parts
    elif len(parts) == 2:
        h = "0"
        m, s = parts
    else:
        h = m = "0"
        s = parts[0]

    return int(h) * 3600 + int(m) * 60 + float(s)


def format_time(seconds, use_short_format=True, always_use_seconds=False):
    if isinstance(seconds, datetime.timedelta):
        seconds = seconds.total_seconds()

    minutes, seconds = divmod(seconds, 60)
    hours, minutes = divmod(int(minutes), 60)

    if always_use_seconds or (use_short_format and hours == 0 and minutes == 0):
        return f"{seconds:06.3f}s"
    elif use_short_format and hours == 0:
        return f"{minutes:02d}:{seconds:06.3f}"
    else:
        return f"{hours:02d}:{minutes:02d}:{seconds:06.3f}"
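

# Diarize the audio with pyannote, write a human-readable log to mtr_dn.txt, and
# return (start, end, duration, speaker) tuples in seconds, with consecutive
# turns from the same speaker merged.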
def generate_diarization(audio_path):
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise ValueError(
            "HF_TOKEN environment variable is not set. Please set it with your Hugging Face token."
        )

    audio = Audio(sample_rate=16000, mono=True)

    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1", use_auth_token=hf_token
    )
    if torch.cuda.is_available():
        pipeline.to(torch.device("cuda"))

    file_path = audio_path
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Could not find the audio file at: {file_path}")
    print(f"Using audio file: {file_path}")

    # Load the audio as 16 kHz mono and run the diarization pipeline on it.
    waveform, sample_rate = audio(file_path)
    file = {"waveform": waveform, "sample_rate": sample_rate, "uri": "mtr"}
    output = pipeline(file)

    # Merge consecutive turns from the same speaker and log each merged segment.
    diarization_segments = []
    txt_file = "mtr_dn.txt"
    with open(txt_file, "w") as f:
        current_speaker = None
        current_start = None
        current_end = None

        for turn, _, speaker in output.itertracks(yield_label=True):
            if speaker != current_speaker:
                if current_speaker is not None:
                    start_time = format_time(current_start)
                    end_time = format_time(current_end)
                    duration = format_time(current_end - current_start)
                    line = f"{start_time} - {end_time} ({duration}): {current_speaker}\n"
                    f.write(line)
                    print(line.strip())
                    diarization_segments.append(
                        (
                            parse_time(start_time),
                            parse_time(end_time),
                            parse_time(duration),
                            current_speaker,
                        )
                    )
                current_speaker = speaker
                current_start = turn.start
                current_end = turn.end
            else:
                current_end = turn.end

        # Flush the final speaker segment.
        if current_speaker is not None:
            start_time = format_time(current_start)
            end_time = format_time(current_end)
            duration = format_time(current_end - current_start)
            line = f"{start_time} - {end_time} ({duration}): {current_speaker}\n"
            f.write(line)
            print(line.strip())
            diarization_segments.append(
                (
                    parse_time(start_time),
                    parse_time(end_time),
                    parse_time(duration),
                    current_speaker,
                )
            )

    print(f"\nHuman-readable diarization results saved to {txt_file}")
    return diarization_segments
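

# Full pipeline: diarize, slice the waveform per speaker turn, transcribe each
# chunk with SenseVoice, and return a formatted, speaker-labelled transcript.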
def process_audio(audio_path, language="yue", fs=16000):
    diarization_segments = generate_diarization(audio_path)

    # Load the audio, resample to the ASR sample rate, and downmix to mono.
    waveform, sample_rate = torchaudio.load(audio_path)
    if sample_rate != fs:
        resampler = torchaudio.transforms.Resample(sample_rate, fs)
        waveform = resampler(waveform)

    input_wav = waveform.mean(0).numpy()

    # Use the full HH:MM:SS format when the diarized speech totals a minute or more.
    total_duration = sum(duration for _, _, duration, _ in diarization_segments)
    use_long_format = total_duration >= 60

    results = []
    for start_time, end_time, duration, speaker in diarization_segments:
        start_seconds = start_time
        end_seconds = end_time

        # Slice the diarized segment out of the mono waveform.
        start_sample = int(start_seconds * fs)
        end_sample = int(end_seconds * fs)
        chunk = input_wav[start_sample:end_sample]

        try:
            text = model.generate(
                input=chunk,
                cache={},
                language=language,
                use_itn=True,
                batch_size_s=500,
                merge_vad=True,
            )
            text = text[0]["text"]

            print(f"Text before clean_and_emoji_annotate_speech: {text}")
            text = clean_and_emoji_annotate_speech(text)

            if not text.strip():
                text = "[inaudible]"

            results.append((speaker, start_time, end_time, duration, text))
        except AssertionError as e:
            # The feature frontend asserts when a chunk is shorter than one analysis window.
            if "choose a window size" in str(e):
                print(
                    f"Warning: Audio segment too short to process. Skipping. Error: {e}"
                )
                results.append((speaker, start_time, end_time, duration, "[too short]"))
            else:
                raise

    formatted_text = ""
    for speaker, start, end, duration, text in results:
        start_str = format_time(start, use_short_format=not use_long_format)
        end_str = format_time(end, use_short_format=not use_long_format)
        duration_str = format_time(duration, use_short_format=True)
        # Assumes a two-speaker recording: SPEAKER_00 -> 1, everyone else -> 2.
        speaker_num = "1" if speaker == "SPEAKER_00" else "2"
        line = f"{start_str} - {end_str} ({duration_str}) Speaker {speaker_num}: {text}"
        formatted_text += line + "\n"
        print(f"Debug: Formatted line: {line}")

    print("Debug: Full formatted text:")
    print(formatted_text)
    return formatted_text.strip()
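

# Example usage: transcribe example/mtr.mp3 (Cantonese) and save the result to mtr.txt.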
if __name__ == "__main__":
    audio_path = "example/mtr.mp3"
    language = "yue"

    # Set to True to run speaker diarization only, without transcription.
    diarization_only = False

    if diarization_only:
        diarization_segments = generate_diarization(audio_path)
    else:
        result = process_audio(audio_path, language)

        output_path = "mtr.txt"
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(result)

        print(f"Diarization and transcription result has been saved to {output_path}")