Merge pull request #330 from pipecat-ai/aleix/stop-and-cancel-are-different

EndFrame tries to end processing gracefully; CancelFrame cancels tasks immediately.
aconchillo authored Jul 31, 2024
2 parents c466d34 + d60e99a commit 62a7a55
Showing 12 changed files with 193 additions and 84 deletions.
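The distinction this commit enforces: an EndFrame flows through the pipeline and lets each processor drain and finish its tasks, while a CancelFrame makes processors stop their tasks right away. A minimal self-contained sketch of that pattern, using plain asyncio and the pipecat frame types; the QueueWorker class itself is hypothetical, not part of pipecat:

import asyncio

from pipecat.frames.frames import CancelFrame, EndFrame, Frame


class QueueWorker:
    """Hypothetical helper showing the stop-vs-cancel pattern used in the diffs below."""

    def __init__(self):
        self._queue: asyncio.Queue = asyncio.Queue()
        self._task = asyncio.create_task(self._handler())

    async def queue_frame(self, frame: Frame):
        await self._queue.put(frame)

    async def stop(self, frame: EndFrame):
        # Graceful shutdown: enqueue the EndFrame, then wait for the task,
        # which exits on its own once it has processed that EndFrame.
        await self._queue.put(frame)
        await self._task

    async def cancel(self, frame: CancelFrame):
        # Hard shutdown: cancel the task immediately and wait for it to unwind.
        self._task.cancel()
        await self._task

    async def _handler(self):
        running = True
        while running:
            try:
                frame = await self._queue.get()
                # ... process the frame here ...
                running = not isinstance(frame, EndFrame)
            except asyncio.CancelledError:
                break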
2 changes: 1 addition & 1 deletion examples/foundational/06a-image-sync.py
@@ -51,7 +51,7 @@ def __init__(self, speaking_path: str, waiting_path: str):
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)

if not isinstance(frame, SystemFrame):
if not isinstance(frame, SystemFrame) and direction == FrameDirection.DOWNSTREAM:
await self.push_frame(ImageRawFrame(image=self._speaking_image_bytes, size=(1024, 1024), format=self._speaking_image_format))
await self.push_frame(frame)
await self.push_frame(ImageRawFrame(image=self._waiting_image_bytes, size=(1024, 1024), format=self._waiting_image_format))
54 changes: 44 additions & 10 deletions src/pipecat/processors/frameworks/rtvi.py
@@ -12,6 +12,8 @@

from pipecat.frames.frames import (
BotInterruptionFrame,
CancelFrame,
EndFrame,
Frame,
InterimTranscriptionFrame,
LLMFullResponseEndFrame,
@@ -343,32 +345,64 @@ def setup_on_start(self, config: RTVIConfig | None, ctor_args: Dict[str, Any]):
self._ctor_args = ctor_args

async def update_config(self, config: RTVIConfig):
await self._handle_config_update(config)
if self._pipeline:
await self._handle_config_update(config)
self._config = config

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)

if isinstance(frame, SystemFrame):
# Specific system frames
if isinstance(frame, CancelFrame):
await self._cancel(frame)
await self.push_frame(frame, direction)
# All other system frames
elif isinstance(frame, SystemFrame):
await self.push_frame(frame, direction)
# Control frames
elif isinstance(frame, StartFrame):
await self._start(frame)
await self._internal_push_frame(frame, direction)
elif isinstance(frame, EndFrame):
# Push EndFrame before stop(), because stop() waits on the task to
# finish and the task finishes when EndFrame is processed.
await self._internal_push_frame(frame, direction)
await self._stop(frame)
# Other frames
else:
await self._frame_queue.put((frame, direction))

if isinstance(frame, StartFrame):
try:
await self._handle_pipeline_setup(frame, self._config)
except Exception as e:
await self._send_error(f"unable to setup RTVI pipeline: {e}")
await self._internal_push_frame(frame, direction)

async def cleanup(self):
if self._pipeline:
await self._pipeline.cleanup()

async def _start(self, frame: StartFrame):
try:
await self._handle_pipeline_setup(frame, self._config)
except Exception as e:
await self._send_error(f"unable to setup RTVI pipeline: {e}")

async def _stop(self, frame: EndFrame):
await self._frame_handler_task

async def _cancel(self, frame: CancelFrame):
self._frame_handler_task.cancel()
await self._frame_handler_task

async def _internal_push_frame(
self,
frame: Frame | None,
direction: FrameDirection | None = FrameDirection.DOWNSTREAM):
await self._frame_queue.put((frame, direction))

async def _frame_handler(self):
while True:
running = True
while running:
try:
(frame, direction) = await self._frame_queue.get()
await self._handle_frame(frame, direction)
self._frame_queue.task_done()
running = not isinstance(frame, EndFrame)
except asyncio.CancelledError:
break

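The ordering comment above is the crux of _stop(): it awaits the frame-handler task, and that task only exits after it dequeues the EndFrame, so the EndFrame must be enqueued first or both sides wait forever. A tiny stand-alone illustration of the same constraint in plain asyncio (names here are invented for the example):

import asyncio

END = object()  # stands in for EndFrame


async def main():
    queue: asyncio.Queue = asyncio.Queue()

    async def handler():
        while await queue.get() is not END:
            pass  # process frames until the end marker arrives

    task = asyncio.create_task(handler())

    # Correct order: enqueue the end marker first, then wait for the task.
    await queue.put(END)
    await task

    # Swapping the two lines above would await a task that is itself waiting
    # for a marker that never gets enqueued -- a deadlock.


asyncio.run(main())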
11 changes: 7 additions & 4 deletions src/pipecat/services/ai_services.py
@@ -283,14 +283,17 @@ async def _append_audio(self, frame: AudioRawFrame):
await self.stop_processing_metrics()
(self._content, self._wave) = self._new_wave()

async def stop(self, frame: EndFrame):
self._wave.close()

async def cancel(self, frame: CancelFrame):
self._wave.close()

async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Processes a frame of audio data, either buffering or transcribing it."""
await super().process_frame(frame, direction)

if isinstance(frame, CancelFrame) or isinstance(frame, EndFrame):
self._wave.close()
await self.push_frame(frame, direction)
elif isinstance(frame, AudioRawFrame):
if isinstance(frame, AudioRawFrame):
# In this service we accumulate audio internally and at the end we
# push a TextFrame. We don't really want to push audio frames down.
await self._append_audio(frame)
3 changes: 3 additions & 0 deletions src/pipecat/services/azure.py
@@ -147,13 +147,16 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
await self._push_queue.put((frame, direction))

async def start(self, frame: StartFrame):
await super().start(frame)
self._speech_recognizer.start_continuous_recognition_async()

async def stop(self, frame: EndFrame):
await super().stop(frame)
self._speech_recognizer.stop_continuous_recognition_async()
self._audio_stream.close()

async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
self._speech_recognizer.stop_continuous_recognition_async()
self._audio_stream.close()

29 changes: 19 additions & 10 deletions src/pipecat/services/cartesia.py
@@ -14,6 +14,7 @@

from pipecat.processors.frame_processor import FrameDirection
from pipecat.frames.frames import (
CancelFrame,
Frame,
AudioRawFrame,
StartInterruptionFrame,
@@ -98,6 +99,10 @@ async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._disconnect()

async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._disconnect()

async def _connect(self):
try:
self._websocket = await websockets.connect(
@@ -111,6 +116,8 @@ async def _connect(self):

async def _disconnect(self):
try:
await self.stop_all_metrics()

if self._context_appending_task:
self._context_appending_task.cancel()
await self._context_appending_task
@@ -120,13 +127,12 @@ async def _disconnect(self):
await self._receive_task
self._receive_task = None
if self._websocket:
ws = self._websocket
await self._websocket.close()
self._websocket = None
await ws.close()

self._context_id = None
self._context_id_start_timestamp = None
self._timestamped_words_buffer = []
await self.stop_all_metrics()
except Exception as e:
logger.exception(f"{self} error closing websocket: {e}")

@@ -142,13 +148,13 @@ async def _receive_task_handler(self):
try:
async for message in self._websocket:
msg = json.loads(message)
# logger.debug(f"Received message: {msg['type']} {msg['context_id']}")
if not msg or msg["context_id"] != self._context_id:
continue
if msg["type"] == "done":
await self.stop_ttfb_metrics()
# unset _context_id but not the _context_id_start_timestamp because we are likely still
# playing out audio and need the timestamp to set send context frames
# Unset _context_id but not the _context_id_start_timestamp
# because we are likely still playing out audio and need the
# timestamp to set send context frames.
self._context_id = None
self._timestamped_words_buffer.append(("LLMFullResponseEndFrame", 0))
elif msg["type"] == "timestamps":
@@ -166,6 +172,8 @@
num_channels=1
)
await self.push_frame(frame)
except asyncio.CancelledError:
pass
except Exception as e:
logger.exception(f"{self} exception: {e}")

@@ -176,15 +184,17 @@ async def _context_appending_task_handler(self):
if not self._context_id_start_timestamp:
continue
elapsed_seconds = time.time() - self._context_id_start_timestamp
# pop all words from self._timestamped_words_buffer that are older than the
# elapsed time and print a message about them to the console
# Pop all words from self._timestamped_words_buffer that are
# older than the elapsed time and print a message about them to
# the console.
while self._timestamped_words_buffer and self._timestamped_words_buffer[0][1] <= elapsed_seconds:
word, timestamp = self._timestamped_words_buffer.pop(0)
if word == "LLMFullResponseEndFrame" and timestamp == 0:
await self.push_frame(LLMFullResponseEndFrame())
continue
# print(f"Word '{word}' with timestamp {timestamp:.2f}s has been spoken.")
await self.push_frame(TextFrame(word))
except asyncio.CancelledError:
pass
except Exception as e:
logger.exception(f"{self} exception: {e}")

@@ -212,7 +222,6 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"language": self._language,
"add_timestamps": True,
}
# logger.debug(f"SENDING MESSAGE {json.dumps(msg)}")
try:
await self._websocket.send(json.dumps(msg))
except Exception as e:
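Both background tasks in this file now catch asyncio.CancelledError, which is what lets _disconnect() (shared by stop() and cancel()) run task.cancel() followed by await task and get a clean return instead of a re-raised cancellation. A stripped-down illustration of that handshake in plain asyncio (the loop body is a placeholder, not the real websocket read):

import asyncio


async def receive_loop():
    try:
        while True:
            await asyncio.sleep(1)  # stand-in for reading from the websocket
    except asyncio.CancelledError:
        pass  # swallow the cancellation so the awaiting side sees a clean exit
    except Exception as e:
        print(f"receive_loop error: {e}")


async def main():
    task = asyncio.create_task(receive_loop())
    await asyncio.sleep(0)  # give the loop a chance to start
    task.cancel()
    await task  # returns normally because CancelledError was handled inside


asyncio.run(main())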
3 changes: 3 additions & 0 deletions src/pipecat/services/deepgram.py
@@ -136,15 +136,18 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
await self.queue_frame(frame, direction)

async def start(self, frame: StartFrame):
await super().start(frame)
if await self._connection.start(self._live_options):
logger.debug(f"{self}: Connected to Deepgram")
else:
logger.error(f"{self}: Unable to connect to Deepgram")

async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._connection.finish()

async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._connection.finish()

async def _on_message(self, *args, **kwargs):
3 changes: 3 additions & 0 deletions src/pipecat/services/gladia.py
@@ -68,14 +68,17 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
await self.queue_frame(frame, direction)

async def start(self, frame: StartFrame):
await super().start(frame)
self._websocket = await websockets.connect(self._url)
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
await self._setup_gladia()

async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._websocket.close()

async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._websocket.close()

async def _setup_gladia(self):
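azure.py, deepgram.py, and gladia.py all pick up the same shape: forward to super() before the service-specific teardown, and implement cancel() separately from stop(). A minimal sketch of that shape; the FakeBaseService stand-in and its empty hooks are assumptions for illustration, not pipecat's actual base-class API:

from pipecat.frames.frames import CancelFrame, EndFrame, StartFrame


class FakeBaseService:
    """Stand-in for the base service class (an assumption, not the real API)."""

    async def start(self, frame: StartFrame): ...
    async def stop(self, frame: EndFrame): ...
    async def cancel(self, frame: CancelFrame): ...


class MySTTService(FakeBaseService):
    async def start(self, frame: StartFrame):
        await super().start(frame)   # let the base class set up its own tasks first
        self._connected = True       # stand-in for opening the real connection

    async def stop(self, frame: EndFrame):
        await super().stop(frame)    # graceful path: base class drains first
        self._connected = False      # then close the service connection

    async def cancel(self, frame: CancelFrame):
        await super().cancel(frame)  # immediate path: base class cancels its tasks
        self._connected = False      # close the connection either way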
33 changes: 23 additions & 10 deletions src/pipecat/transports/base_input.py
@@ -46,12 +46,26 @@ async def start(self, frame: StartFrame):
self._audio_in_queue = asyncio.Queue()
self._audio_task = self.get_event_loop().create_task(self._audio_task_handler())

async def stop(self):
# Wait for the task to finish.
async def stop(self, frame: EndFrame):
# Cancel and wait for the audio input task to finish.
if self._params.audio_in_enabled or self._params.vad_enabled:
self._audio_task.cancel()
await self._audio_task

# Wait for the push frame task to finish. It will finish when the
# EndFrame is actually processed.
await self._push_frame_task

async def cancel(self, frame: CancelFrame):
# Cancel all the tasks and wait for them to finish.

if self._params.audio_in_enabled or self._params.vad_enabled:
self._audio_task.cancel()
await self._audio_task

self._push_frame_task.cancel()
await self._push_frame_task

def vad_analyzer(self) -> VADAnalyzer | None:
return self._params.vad_analyzer

@@ -63,17 +77,12 @@ async def push_audio_frame(self, frame: AudioRawFrame):
# Frame processor
#

async def cleanup(self):
self._push_frame_task.cancel()
await self._push_frame_task

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)

# Specific system frames
if isinstance(frame, CancelFrame):
await self.stop()
# We don't queue a CancelFrame since we want to stop ASAP.
await self.cancel(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, BotInterruptionFrame):
await self._handle_interruptions(frame, False)
@@ -89,8 +98,10 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
await self.start(frame)
await self._internal_push_frame(frame, direction)
elif isinstance(frame, EndFrame):
# Push EndFrame before stop(), because stop() waits on the task to
# finish and the task finishes when EndFrame is processed.
await self._internal_push_frame(frame, direction)
await self.stop()
await self.stop(frame)
# Other frames
else:
await self._internal_push_frame(frame, direction)
@@ -111,10 +122,12 @@
await self._push_queue.put((frame, direction))

async def _push_frame_task_handler(self):
while True:
running = True
while running:
try:
(frame, direction) = await self._push_queue.get()
await self.push_frame(frame, direction)
running = not isinstance(frame, EndFrame)
self._push_queue.task_done()
except asyncio.CancelledError:
break