# InterpreTalk/backend/seamless_utils.py
# base seamless imports
# ---------------------------------
import io
import numpy as np
import soundfile
import torch
from pydub import AudioSegment
# ---------------------------------
# seamless-streaming specific imports
# ---------------------------------
import math
from simuleval.data.segments import SpeechSegment, EmptySegment
from seamless_communication.streaming.agents.seamless_streaming_s2st import (
SeamlessStreamingS2STVADAgent,
)
from simuleval.utils.arguments import cli_argument_list
from simuleval import options
from typing import Union, List
from simuleval.data.segments import Segment, TextSegment
from simuleval.agents.pipeline import TreeAgentPipeline
from simuleval.agents.states import AgentStates
# ---------------------------------
# seamless setup
# source: https://colab.research.google.com/github/kauterry/seamless_communication/blob/main/Seamless_Tutorial.ipynb?
SAMPLE_RATE = 16000
# PM - This class simulates the audio front-end in the seamless-streaming pipeline.
# It needs to be replaced with the actual audio front-end.
# TODO: replacement class that takes in PCM-16 bytes and returns SpeechSegment
class AudioFrontEnd:
    def __init__(self, wav_file, segment_size) -> None:
        self.samples, self.sample_rate = soundfile.read(wav_file)
        assert (
            self.sample_rate == SAMPLE_RATE
        ), f"expected {SAMPLE_RATE} Hz audio, got {self.sample_rate} Hz"
        self.segment_size = segment_size  # segment length in milliseconds
        self.step = 0  # read cursor, in samples
    def send_segment(self):
        """
        Front-end logic from simuleval's instance.py: emit the next
        segment_size milliseconds of audio as a SpeechSegment.
        """
        num_samples = math.ceil(self.segment_size / 1000 * self.sample_rate)
        if self.step < len(self.samples):
            if self.step + num_samples >= len(self.samples):
                samples = self.samples[self.step :]
                is_finished = True
            else:
                samples = self.samples[self.step : self.step + num_samples]
                is_finished = False
            # Only advance the read cursor; truncating self.samples here as
            # well would shift the buffer under the cursor and skip audio.
            self.step = min(self.step + num_samples, len(self.samples))
            segment = SpeechSegment(
                content=samples,
                sample_rate=self.sample_rate,
                finished=is_finished,
            )
        else:
            # Finished reading this audio; reset for a new stream.
            segment = EmptySegment(
                finished=True,
            )
            self.step = 0
            self.samples = []
        return segment
    def add_segments(self, wav):
        """Append newly received audio (a wav chunk) to the sample buffer."""
        new_samples, _ = soundfile.read(wav)
        self.samples = np.concatenate((self.samples, new_samples))
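
# Hypothetical usage sketch for AudioFrontEnd; "input.wav" and the 320 ms
# chunk size are placeholder assumptions, not values fixed by this module:
#
#   frontend = AudioFrontEnd("input.wav", segment_size=320)
#   while True:
#       seg = frontend.send_segment()
#       if isinstance(seg, EmptySegment):
#           break
#       ...  # push seg into the streaming system
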
class OutputSegments:
    def __init__(self, segments: Union[List[Segment], Segment]):
        if isinstance(segments, Segment):
            segments = [segments]
        self.segments: List[Segment] = list(segments)
@property
def is_empty(self):
return all(segment.is_empty for segment in self.segments)
@property
def finished(self):
return all(segment.finished for segment in self.segments)
def get_audiosegment(samples, sr):
    """Convert raw samples into a pydub AudioSegment via an in-memory wav."""
    b = io.BytesIO()
    soundfile.write(b, samples, samplerate=sr, format="wav")
    b.seek(0)
    return AudioSegment.from_file(b)
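
# Hypothetical example: wrap one second of silence as an AudioSegment and
# export it (the output path is a placeholder):
#
#   silence = get_audiosegment(np.zeros(SAMPLE_RATE, dtype=np.float32), SAMPLE_RATE)
#   silence.export("silence.wav", format="wav")
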
def reset_states(system, states):
if isinstance(system, TreeAgentPipeline):
states_iter = states.values()
else:
states_iter = states
for state in states_iter:
state.reset()
def get_states_root(system, states) -> AgentStates:
    if isinstance(system, TreeAgentPipeline):
        # states is a dict keyed by module
        return states[system.source_module]
    else:
        # states is a list
        return states[0]
def build_streaming_system(model_configs, agent_class):
parser = options.general_parser()
parser.add_argument("-f", "--f", help="a dummy argument to fool ipython", default="1")
agent_class.add_args(parser)
args, _ = parser.parse_known_args(cli_argument_list(model_configs))
system = agent_class.from_args(args)
return system
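
# A minimal configuration sketch for build_streaming_system, loosely following
# the Seamless tutorial linked above; the concrete values are assumptions, not
# settings shipped with this module:
#
#   model_configs = dict(
#       source_segment_size=320,  # ms of audio per input chunk
#       task="s2st",
#       tgt_lang="spa",
#       device="cuda:0",
#       dtype="fp16",
#   )
#   system = build_streaming_system(model_configs, SeamlessStreamingS2STVADAgent)
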
def run_streaming_inference(system, audio_frontend, system_states, tgt_lang):
    # NOTE: for visualization we calculate delay offsets from the audio
    # *BEFORE* VAD segmentation. SimulEval evaluation, in contrast, assumes
    # pre-segmented audio, and its Average Lagging / End Offset metrics are
    # computed on those pre-segmented clips. The delays here are therefore
    # *NOT* comparable to SimulEval's per-segment delays.
delays = {"s2st": [], "s2tt": []}
prediction_lists = {"s2st": [], "s2tt": []}
speech_durations = []
curr_delay = 0
target_sample_rate = None
while True:
input_segment = audio_frontend.send_segment()
input_segment.tgt_lang = tgt_lang
curr_delay += len(input_segment.content) / SAMPLE_RATE * 1000
if input_segment.finished:
            # a hack: we expect a real stream to end with silence
get_states_root(system, system_states).source_finished = True
# Translation happens here
if isinstance(input_segment, EmptySegment):
return None, None, None, None
output_segments = OutputSegments(system.pushpop(input_segment, system_states))
if not output_segments.is_empty:
for segment in output_segments.segments:
# NOTE: another difference from SimulEval evaluation -
# delays are accumulated per-token
if isinstance(segment, SpeechSegment):
pred_duration = 1000 * len(segment.content) / segment.sample_rate
speech_durations.append(pred_duration)
delays["s2st"].append(curr_delay)
prediction_lists["s2st"].append(segment.content)
target_sample_rate = segment.sample_rate
elif isinstance(segment, TextSegment):
delays["s2tt"].append(curr_delay)
prediction_lists["s2tt"].append(segment.content)
print(curr_delay, segment.content)
if output_segments.finished:
reset_states(system, system_states)
if input_segment.finished:
# an assumption of SimulEval agents -
# once source_finished=True, generate until output translation is finished
break
return delays, prediction_lists, speech_durations, target_sample_rate
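
# Hypothetical driver sketch: build_states() is the simuleval agent API for
# creating fresh pipeline states; the wav path and tgt_lang are placeholders:
#
#   system_states = system.build_states()
#   frontend = AudioFrontEnd("input.wav", segment_size=320)
#   delays, predictions, durations, tgt_sr = run_streaming_inference(
#       system, frontend, system_states, tgt_lang="spa"
#   )
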
def get_s2st_delayed_targets(delays, target_sample_rate, prediction_lists, speech_durations):
    # Calculate intervals and durations for the s2st output: prepend silence up
    # to the first prediction offset, then lay segments out at their delays.
    intervals = []
    start = prev_end = prediction_offset = delays["s2st"][0]
    target_samples = [0.0] * int(target_sample_rate * prediction_offset / 1000)
for i, delay in enumerate(delays["s2st"]):
start = max(prev_end, delay)
if start > prev_end:
            # waiting for source speech: fill the discontinuity with silence
target_samples += [0.0] * int(
target_sample_rate * (start - prev_end) / 1000
)
target_samples += prediction_lists["s2st"][i]
duration = speech_durations[i]
prev_end = start + duration
intervals.append([start, duration])
return target_samples, intervals
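
# ---------------------------------
# Hypothetical end-to-end sketch wiring the helpers above together, in the
# spirit of the Seamless tutorial. "input.wav", "translated.wav", and every
# config value below are placeholder assumptions, not repo settings.
if __name__ == "__main__":
    model_configs = dict(
        source_segment_size=320,  # ms of audio per input chunk
        task="s2st",
        tgt_lang="spa",
        device="cuda:0" if torch.cuda.is_available() else "cpu",
        dtype="fp16" if torch.cuda.is_available() else "fp32",
    )
    system = build_streaming_system(model_configs, SeamlessStreamingS2STVADAgent)
    system_states = system.build_states()
    frontend = AudioFrontEnd("input.wav", segment_size=320)
    delays, predictions, durations, tgt_sr = run_streaming_inference(
        system, frontend, system_states, tgt_lang="spa"
    )
    if delays is not None and predictions["s2st"]:
        target_samples, _intervals = get_s2st_delayed_targets(
            delays, tgt_sr, predictions, durations
        )
        get_audiosegment(np.array(target_samples), tgt_sr).export(
            "translated.wav", format="wav"
        )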