#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD-2-Clause
#
import asyncio
from itertools import chain
from typing import List

from loguru import logger

from pipecat.frames.frames import CancelFrame, EndFrame, Frame, StartFrame
from pipecat.pipeline.base_pipeline import BasePipeline
from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
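
# A ParallelPipeline runs several sub-pipelines side by side. Each branch is
# wrapped in a Source (its entry point) and a Sink (its exit point). All
# Sources share one upstream queue and all Sinks share one downstream queue,
# so frames fan out to every branch and are merged back, deduplicated by
# frame id, before leaving the ParallelPipeline.
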
class Source(FrameProcessor):
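    """Entry point of a parallel branch.

    Downstream frames are pushed into the wrapped pipeline; upstream frames
    coming back out of the branch are put on the shared upstream queue so
    the parent ParallelPipeline can merge them.
    """
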
def __init__(self, upstream_queue: asyncio.Queue):
super().__init__()
self._up_queue = upstream_queue

    async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
match direction:
case FrameDirection.UPSTREAM:
await self._up_queue.put(frame)
case FrameDirection.DOWNSTREAM:
await self.push_frame(frame, direction)


class Sink(FrameProcessor):
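    """Exit point of a parallel branch.

    Upstream frames are pushed back into the wrapped pipeline; downstream
    frames leaving the branch are put on the shared downstream queue so the
    parent ParallelPipeline can merge them.
    """
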
def __init__(self, downstream_queue: asyncio.Queue):
super().__init__()
self._down_queue = downstream_queue

    async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
match direction:
case FrameDirection.UPSTREAM:
await self.push_frame(frame, direction)
case FrameDirection.DOWNSTREAM:
await self._down_queue.put(frame)


class ParallelPipeline(BasePipeline):
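    """Runs multiple pipelines concurrently over the same frame stream.

    Each positional argument is a list of processors that becomes one branch:
    an internal Pipeline wrapped in a Source and a Sink. Frames entering the
    ParallelPipeline are broadcast to every branch, and frames leaving the
    branches are merged through shared queues, deduplicated by frame id.
    """
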
def __init__(self, *args):
super().__init__()
        if not args:
            raise ValueError("ParallelPipeline needs at least one argument")

self._sources = []
self._sinks = []
self._up_queue = asyncio.Queue()
self._down_queue = asyncio.Queue()
self._up_task: asyncio.Task | None = None
self._down_task: asyncio.Task | None = None
self._pipelines = []
logger.debug(f"Creating {self} pipelines")
for processors in args:
if not isinstance(processors, list):
raise TypeError(f"ParallelPipeline argument {processors} is not a list")
# We will add a source before the pipeline and a sink after.
source = Source(self._up_queue)
sink = Sink(self._down_queue)
self._sources.append(source)
self._sinks.append(sink)
# Create pipeline
pipeline = Pipeline(processors)
source.link(pipeline)
pipeline.link(sink)
self._pipelines.append(pipeline)
logger.debug(f"Finished creating {self} pipelines")

    #
    # BasePipeline
    #

def processors_with_metrics(self) -> List[FrameProcessor]:
return list(chain.from_iterable(p.processors_with_metrics() for p in self._pipelines))

    #
    # Frame processor
    #

async def cleanup(self):
await asyncio.gather(*[p.cleanup() for p in self._pipelines])

    async def _start_tasks(self):
loop = self.get_event_loop()
self._up_task = loop.create_task(self._process_up_queue())
self._down_task = loop.create_task(self._process_down_queue())

    async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, StartFrame):
await self._start_tasks()
        if direction == FrameDirection.UPSTREAM:
            # If we get an upstream frame, we process it in each sink.
            await asyncio.gather(*[s.process_frame(frame, direction) for s in self._sinks])
        elif direction == FrameDirection.DOWNSTREAM:
            # If we get a downstream frame, we process it in each source.
            # TODO(aleix): We are creating a task for each frame. For
            # real-time video/audio this might be too slow; we should reuse
            # an already created task instead.
            await asyncio.gather(*[s.process_frame(frame, direction) for s in self._sources])
        # If we get an EndFrame or a CancelFrame, tell the queue processing
        # tasks to finish and wait for them to do so.
        if isinstance(frame, (CancelFrame, EndFrame)):
# Use None to indicate when queues should be done processing.
await self._up_queue.put(None)
await self._down_queue.put(None)
if self._up_task:
await self._up_task
if self._down_task:
await self._down_task
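
    # The two tasks below drain the shared queues and forward frames out of
    # the ParallelPipeline. A None sentinel (queued on EndFrame/CancelFrame)
    # tells them to exit. Since every branch receives the same broadcast
    # frame, the same frame id may reach a queue more than once; seen_ids
    # ensures each frame is only pushed once.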
async def _process_up_queue(self):
running = True
seen_ids = set()
while running:
frame = await self._up_queue.get()
if frame and frame.id not in seen_ids:
await self.push_frame(frame, FrameDirection.UPSTREAM)
seen_ids.add(frame.id)
running = frame is not None
self._up_queue.task_done()

    async def _process_down_queue(self):
running = True
seen_ids = set()
while running:
frame = await self._down_queue.get()
if frame and frame.id not in seen_ids:
await self.push_frame(frame, FrameDirection.DOWNSTREAM)
seen_ids.add(frame.id)
running = frame is not None
self._down_queue.task_done()
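

# Usage sketch (illustrative, not part of the original module). Assuming two
# hypothetical FrameProcessor subclasses, ProcessorA and ProcessorB, a
# ParallelPipeline drops into a regular Pipeline like any other processor:
#
#     pipeline = Pipeline(
#         [
#             ParallelPipeline([ProcessorA()], [ProcessorB()]),
#         ]
#     )
#
# Both branches receive every downstream frame; their outputs are merged and
# deduplicated before continuing past the ParallelPipeline.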