File size: 5,441 Bytes
8d7f55c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
from itertools import chain
from typing import List
from pipecat.pipeline.base_pipeline import BasePipeline
from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.frames.frames import CancelFrame, EndFrame, Frame, StartFrame
from loguru import logger
class Source(FrameProcessor):
    """Entry processor that ParallelPipeline links in front of each pipeline.

    Frames travelling downstream are pushed on unchanged. Frames travelling
    upstream are not pushed any further; they are put on the queue that was
    handed to the constructor instead.
    """

    def __init__(self, upstream_queue: asyncio.Queue):
        """Store the queue that will receive every upstream frame."""
        super().__init__()
        self._up_queue = upstream_queue

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Route ``frame`` according to the direction it is travelling in."""
        await super().process_frame(frame, direction)

        if direction == FrameDirection.UPSTREAM:
            # Upstream frames are collected on the queue rather than pushed.
            await self._up_queue.put(frame)
        elif direction == FrameDirection.DOWNSTREAM:
            await self.push_frame(frame, direction)
class Sink(FrameProcessor):
    """Exit processor that ParallelPipeline links after each pipeline.

    Frames travelling upstream are pushed on unchanged. Frames travelling
    downstream are not pushed any further; they are put on the queue that was
    handed to the constructor instead.
    """

    def __init__(self, downstream_queue: asyncio.Queue):
        """Store the queue that will receive every downstream frame."""
        super().__init__()
        self._down_queue = downstream_queue

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Route ``frame`` according to the direction it is travelling in."""
        await super().process_frame(frame, direction)

        if direction == FrameDirection.UPSTREAM:
            await self.push_frame(frame, direction)
        elif direction == FrameDirection.DOWNSTREAM:
            # Downstream frames are collected on the queue rather than pushed.
            await self._down_queue.put(frame)
class ParallelPipeline(BasePipeline):
    """Feed every frame to several pipelines and merge what comes out.

    Each positional argument is a list of processors. Every list is wrapped in
    its own ``Pipeline`` with a ``Source`` linked in front of it and a ``Sink``
    linked after it. All sources share one upstream queue and all sinks share
    one downstream queue; two background tasks drain those queues and push the
    frames out of this processor, skipping frame ids that were already pushed.
    """

    def __init__(self, *args):
        """Build one source -> pipeline -> sink chain per argument.

        Args:
            *args: one list of processors per parallel branch.

        Raises:
            Exception: if no argument is given.
            TypeError: if any argument is not a list.
        """
        super().__init__()

        if not args:
            raise Exception("ParallelPipeline needs at least one argument")

        self._sources = []
        self._sinks = []

        # Shared by every branch: sources put upstream frames on the first
        # queue, sinks put downstream frames on the second one.
        self._up_queue = asyncio.Queue()
        self._down_queue = asyncio.Queue()
        self._up_task: asyncio.Task | None = None
        self._down_task: asyncio.Task | None = None

        self._pipelines = []

        logger.debug(f"Creating {self} pipelines")
        for processors in args:
            if not isinstance(processors, list):
                raise TypeError(f"ParallelPipeline argument {processors} is not a list")

            # We will add a source before the pipeline and a sink after.
            source = Source(self._up_queue)
            sink = Sink(self._down_queue)
            self._sources.append(source)
            self._sinks.append(sink)

            # Create pipeline
            pipeline = Pipeline(processors)
            source.link(pipeline)
            pipeline.link(sink)
            self._pipelines.append(pipeline)

        logger.debug(f"Finished creating {self} pipelines")

    #
    # BasePipeline
    #

    def processors_with_metrics(self) -> List[FrameProcessor]:
        """Return the concatenated ``processors_with_metrics()`` of every branch."""
        return list(chain.from_iterable(p.processors_with_metrics() for p in self._pipelines))

    #
    # Frame processor
    #

    async def cleanup(self):
        """Run ``cleanup()`` on every branch pipeline concurrently."""
        await asyncio.gather(*[p.cleanup() for p in self._pipelines])

    async def _start_tasks(self):
        """Create the two tasks that drain the upstream and downstream queues."""
        loop = self.get_event_loop()
        self._up_task = loop.create_task(self._process_up_queue())
        self._down_task = loop.create_task(self._process_down_queue())

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Hand ``frame`` to every branch and manage the queue tasks.

        A ``StartFrame`` starts the queue tasks. A ``CancelFrame`` or
        ``EndFrame`` queues the stop sentinel on both queues and waits for the
        tasks (if they were started) to finish.
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, StartFrame):
            await self._start_tasks()

        if direction == FrameDirection.UPSTREAM:
            # If we get an upstream frame we process it in each sink.
            await asyncio.gather(*[s.process_frame(frame, direction) for s in self._sinks])
        elif direction == FrameDirection.DOWNSTREAM:
            # If we get a downstream frame we process it in each source.
            # TODO(aleix): We are creating task for each frame. For real-time
            # video/audio this might be too slow. We should use an already
            # created task instead.
            await asyncio.gather(*[s.process_frame(frame, direction) for s in self._sources])

        # If we get a CancelFrame or an EndFrame we stop our queue processing
        # tasks and wait for them to finish.
        if isinstance(frame, (CancelFrame, EndFrame)):
            # Use None to indicate when queues should be done processing.
            await self._up_queue.put(None)
            await self._down_queue.put(None)
            if self._up_task is not None:
                await self._up_task
            if self._down_task is not None:
                await self._down_task

    async def _process_up_queue(self):
        """Push frames from the upstream queue upstream until the sentinel."""
        await self._process_queue(self._up_queue, FrameDirection.UPSTREAM)

    async def _process_down_queue(self):
        """Push frames from the downstream queue downstream until the sentinel."""
        await self._process_queue(self._down_queue, FrameDirection.DOWNSTREAM)

    async def _process_queue(self, queue: asyncio.Queue, direction: FrameDirection):
        """Drain ``queue``, pushing each distinct frame once in ``direction``.

        Every branch receives the same frame object, so the same frame can be
        queued once per branch; frames whose ``id`` was already pushed are
        skipped. The loop ends when the ``None`` sentinel is dequeued.
        """
        # NOTE(review): this set only ever grows for the lifetime of the task;
        # for long-running sessions consider bounding it.
        seen_ids = set()
        while True:
            frame = await queue.get()
            # Compare against the sentinel by identity so that a frame object
            # that happens to be falsy is never mistaken for "no frame".
            if frame is None:
                queue.task_done()
                break
            if frame.id not in seen_ids:
                await self.push_frame(frame, direction)
                seen_ids.add(frame.id)
            queue.task_done()
|