# NOTE: "Spaces: / Paused / Paused" removed here — it was residue from the
# Hugging Face Spaces web page this file was scraped from, not part of the source.
# base seamless imports
# ---------------------------------
import io
import json
import matplotlib as mpl
import matplotlib.pyplot as plt
import mmap
import numpy as np
import soundfile
import torchaudio
import torch
from pydub import AudioSegment
# ---------------------------------
# seamless-streaming specific imports
# ---------------------------------
import math
from simuleval.data.segments import SpeechSegment, EmptySegment
from seamless_communication.streaming.agents.seamless_streaming_s2st import (
    SeamlessStreamingS2STVADAgent,
)
from simuleval.utils.arguments import cli_argument_list
from simuleval import options
from typing import Union, List
from simuleval.data.segments import Segment, TextSegment
from simuleval.agents.pipeline import TreeAgentPipeline
from simuleval.agents.states import AgentStates
# ---------------------------------
# seamless setup
# source: https://colab.research.google.com/github/kauterry/seamless_communication/blob/main/Seamless_Tutorial.ipynb?
# Sample rate (Hz) the seamless streaming model expects for input audio.
SAMPLE_RATE = 16000

# PM - This class is used to simulate the audio frontend in the seamless streaming pipeline
# need to replace this with the actual audio frontend
# TODO: replacement class that takes in PCM-16 bytes and returns SpeechSegment
class AudioFrontEnd:
    """Simulated streaming audio front-end.

    Reads a wav file up front and serves it back in fixed-size segments,
    mimicking the front-end logic in simuleval's instance.py. `add_segments`
    lets the caller append more audio while streaming.
    """

    def __init__(self, wav_file, segment_size) -> None:
        # segment_size is in milliseconds.
        self.samples, self.sample_rate = soundfile.read(wav_file)
        print(self.sample_rate, "sample rate")
        # The model expects 16 kHz input; fail fast on anything else.
        assert self.sample_rate == SAMPLE_RATE
        self.segment_size = segment_size
        # Index of the first not-yet-sent sample.
        self.step = 0

    def send_segment(self):
        """
        This is the front-end logic in simuleval instance.py

        Returns the next SpeechSegment of up to `segment_size` ms, or an
        EmptySegment (finished=True) once the audio is exhausted.
        """
        num_samples = math.ceil(self.segment_size / 1000 * self.sample_rate)
        if self.step < len(self.samples):
            end = self.step + num_samples
            # finished when this chunk reaches the end of the buffered audio
            is_finished = end >= len(self.samples)
            samples = self.samples[self.step : end]
            # BUG FIX: only advance the read pointer. The original also
            # truncated self.samples to self.samples[self.step:] while keeping
            # the advanced step, so from the third call onward whole chunks of
            # audio were silently skipped.
            self.step = min(end, len(self.samples))
            segment = SpeechSegment(
                content=samples,
                sample_rate=self.sample_rate,
                finished=is_finished,
            )
        else:
            # Finish reading this audio; reset so the front-end can be reused.
            segment = EmptySegment(
                finished=True,
            )
            self.step = 0
            self.samples = []
        return segment

    def add_segments(self, wav):
        """Append newly arrived audio (a wav file/buffer) to the stream."""
        new_samples, _ = soundfile.read(wav)
        self.samples = np.concatenate((self.samples, new_samples))
class OutputSegments:
    """Group the output of a pipeline step (one Segment or a list of them)
    so emptiness/finished status can be queried over all of them at once."""

    def __init__(self, segments: Union[List[Segment], Segment]):
        # Normalize a lone Segment into a one-element list.
        if isinstance(segments, Segment):
            segments = [segments]
        self.segments: List[Segment] = list(segments)

    def is_empty(self):
        """Return True when every wrapped segment is empty."""
        return all(s.is_empty for s in self.segments)

    def finished(self):
        """Return True when every wrapped segment is finished."""
        return all(s.finished for s in self.segments)
def get_audiosegment(samples, sr):
    """Pack raw audio samples into a pydub AudioSegment via an in-memory WAV."""
    wav_buffer = io.BytesIO()
    soundfile.write(wav_buffer, samples, samplerate=sr, format="wav")
    # Rewind so pydub reads from the start of the buffer.
    wav_buffer.seek(0)
    return AudioSegment.from_file(wav_buffer)
def reset_states(system, states):
    """Reset every AgentStates object belonging to the pipeline.

    A TreeAgentPipeline keeps its states in a dict (module -> states);
    other pipelines keep a flat list.
    """
    state_collection = states.values() if isinstance(system, TreeAgentPipeline) else states
    for agent_state in state_collection:
        agent_state.reset()
def get_states_root(system, states) -> AgentStates:
    """Return the root (source-side) AgentStates of the pipeline.

    For a TreeAgentPipeline `states` is a dict keyed by agent module;
    otherwise it is a flat list whose first entry is the root.
    """
    if isinstance(system, TreeAgentPipeline):
        # self.states is a dict
        return states[system.source_module]
    else:
        # self.states is a list.
        # BUG FIX: use the `states` argument that was passed in — the original
        # returned system.states[0], ignoring `states`, which is inconsistent
        # with reset_states (it iterates the passed-in states).
        return states[0]
def build_streaming_system(model_configs, agent_class):
    """Instantiate a streaming agent pipeline from a model-config mapping."""
    parser = options.general_parser()
    # Dummy flag so the parser tolerates the "-f" argument ipython injects.
    parser.add_argument("-f", "--f", help="a dummy argument to fool ipython", default="1")
    agent_class.add_args(parser)
    parsed_args, _unknown = parser.parse_known_args(cli_argument_list(model_configs))
    return agent_class.from_args(parsed_args)
def run_streaming_inference(system, audio_frontend, system_states, tgt_lang):
    """Stream audio through the system and collect per-modality predictions.

    Returns (delays, prediction_lists, speech_durations, target_sample_rate),
    or (None, None, None, None) when the front-end yields an EmptySegment
    (no audio left to stream).
    """
    # NOTE: Here for visualization, we calculate delays offset from audio
    # *BEFORE* VAD segmentation.
    # In contrast for SimulEval evaluation, we assume audios are pre-segmented,
    # and Average Lagging, End Offset metrics are based on those pre-segmented audios.
    # Thus, delays here are *NOT* comparable to SimulEval per-segment delays
    delays = {"s2st": [], "s2tt": []}
    prediction_lists = {"s2st": [], "s2tt": []}
    speech_durations = []
    curr_delay = 0  # milliseconds of source audio consumed so far
    target_sample_rate = None
    while True:
        input_segment = audio_frontend.send_segment()
        input_segment.tgt_lang = tgt_lang
        curr_delay += len(input_segment.content) / SAMPLE_RATE * 1000
        if input_segment.finished:
            # a hack, we expect a real stream to end with silence
            get_states_root(system, system_states).source_finished = True
        # Translation happens here
        if isinstance(input_segment, EmptySegment):
            return None, None, None, None
        output_segments = OutputSegments(system.pushpop(input_segment, system_states))
        # BUG FIX: is_empty/finished are *methods* on OutputSegments, but were
        # read as plain attributes. `not <bound method>` is always False, so no
        # predictions were ever collected; a bound method is always truthy, so
        # states were reset after every chunk. They must be called.
        if not output_segments.is_empty():
            for segment in output_segments.segments:
                # NOTE: another difference from SimulEval evaluation -
                # delays are accumulated per-token
                if isinstance(segment, SpeechSegment):
                    pred_duration = 1000 * len(segment.content) / segment.sample_rate
                    speech_durations.append(pred_duration)
                    delays["s2st"].append(curr_delay)
                    prediction_lists["s2st"].append(segment.content)
                    target_sample_rate = segment.sample_rate
                elif isinstance(segment, TextSegment):
                    delays["s2tt"].append(curr_delay)
                    prediction_lists["s2tt"].append(segment.content)
                    print(curr_delay, segment.content)
        if output_segments.finished():
            reset_states(system, system_states)
            if input_segment.finished:
                # an assumption of SimulEval agents -
                # once source_finished=True, generate until output translation is finished
                break
    return delays, prediction_lists, speech_durations, target_sample_rate
def get_s2st_delayed_targets(delays, target_sample_rate, prediction_lists, speech_durations):
    """Lay out s2st speech predictions on an absolute timeline.

    Args:
        delays: dict with key "s2st" -> emission delays in ms, one per prediction.
        target_sample_rate: sample rate (Hz) of the predicted speech.
        prediction_lists: dict with key "s2st" -> sample sequences, parallel
            to delays["s2st"].
        speech_durations: per-prediction durations in ms, also parallel.

    Returns:
        (target_samples, intervals): the concatenated waveform with silence
        filled in wherever the model was waiting for source speech, plus a
        list of [start_ms, duration_ms] intervals, one per prediction.
    """
    # Robustness fix: the original indexed delays["s2st"][0] unconditionally
    # and raised IndexError when there were no speech outputs.
    if not delays["s2st"]:
        return [], []
    intervals = []
    prediction_offset = delays["s2st"][0]
    prev_end = prediction_offset
    # Leading silence up to the first prediction.
    target_samples = [0.0] * int(target_sample_rate * prediction_offset / 1000)
    for delay, prediction, duration in zip(
        delays["s2st"], prediction_lists["s2st"], speech_durations
    ):
        start = max(prev_end, delay)
        if start > prev_end:
            # Wait source speech, add discontinuity with silence
            target_samples += [0.0] * int(
                target_sample_rate * (start - prev_end) / 1000
            )
        target_samples += prediction
        prev_end = start + duration
        intervals.append([start, duration])
    return target_samples, intervals