Skip to content

Commit

Permalink
services: fix infinite websocket-bases TTS services retries
Browse files Browse the repository at this point in the history
Fixes #871
  • Loading branch information
aconchillo committed Dec 16, 2024
1 parent 8e140b2 commit 653fbb7
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 99 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ All notable changes to **Pipecat** will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.0.51] - 2024-12-16

### Fixed

- Fixed an issue in websocket-based TTS services that was causing infinite
reconnections (Cartesia, ElevenLabs, PlayHT and LMNT).

## [0.0.50] - 2024-12-11

### Added
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies = [
"pydantic~=2.8.2",
"pyloudnorm~=0.1.1",
"resampy~=0.4.3",
"tenacity~=9.0.0"
]

[project.urls]
Expand All @@ -55,7 +56,7 @@ gstreamer = [ "pygobject~=3.48.2" ]
fireworks = [ "openai~=1.50.2" ]
krisp = [ "pipecat-ai-krisp~=0.3.0" ]
langchain = [ "langchain~=0.2.14", "langchain-community~=0.2.12", "langchain-openai~=0.1.20" ]
livekit = [ "livekit~=0.17.5", "livekit-api~=0.7.1", "tenacity~=8.5.0" ]
livekit = [ "livekit~=0.17.5", "livekit-api~=0.7.1" ]
lmnt = [ "lmnt~=1.1.4" ]
local = [ "pyaudio~=0.2.14" ]
moondream = [ "einops~=0.8.0", "timm~=1.0.8", "transformers~=4.44.0" ]
Expand Down
94 changes: 54 additions & 40 deletions src/pipecat/services/cartesia.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from loguru import logger
from pydantic import BaseModel
from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential


from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
Expand Down Expand Up @@ -239,52 +241,64 @@ async def flush_audio(self):
msg = self._build_msg(text="", continue_transcript=False)
await self._websocket.send(msg)

async def _receive_messages(self):
async for message in self._get_websocket():
msg = json.loads(message)
if not msg or msg["context_id"] != self._context_id:
continue
if msg["type"] == "done":
await self.stop_ttfb_metrics()
# Unset _context_id but not the _context_id_start_timestamp
# because we are likely still playing out audio and need the
# timestamp to set send context frames.
self._context_id = None
await self.add_word_timestamps(
[("TTSStoppedFrame", 0), ("LLMFullResponseEndFrame", 0), ("Reset", 0)]
)
elif msg["type"] == "timestamps":
await self.add_word_timestamps(
list(zip(msg["word_timestamps"]["words"], msg["word_timestamps"]["start"]))
)
elif msg["type"] == "chunk":
await self.stop_ttfb_metrics()
self.start_word_timestamps()
frame = TTSAudioRawFrame(
audio=base64.b64decode(msg["data"]),
sample_rate=self._settings["output_format"]["sample_rate"],
num_channels=1,
)
await self.push_frame(frame)
elif msg["type"] == "error":
logger.error(f"{self} error: {msg}")
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
else:
logger.error(f"{self} error, unknown message type: {msg}")

async def _reconnect_websocket(self, retry_state: RetryCallState):
logger.warning(f"{self} reconnecting (attempt: {retry_state.attempt_number})")
await self._disconnect_websocket()
await self._connect_websocket()

async def _receive_task_handler(self):
while True:
try:
async for message in self._get_websocket():
msg = json.loads(message)
if not msg or msg["context_id"] != self._context_id:
continue
if msg["type"] == "done":
await self.stop_ttfb_metrics()
# Unset _context_id but not the _context_id_start_timestamp
# because we are likely still playing out audio and need the
# timestamp to set send context frames.
self._context_id = None
await self.add_word_timestamps(
[("TTSStoppedFrame", 0), ("LLMFullResponseEndFrame", 0), ("Reset", 0)]
)
elif msg["type"] == "timestamps":
await self.add_word_timestamps(
list(
zip(
msg["word_timestamps"]["words"], msg["word_timestamps"]["start"]
)
)
)
elif msg["type"] == "chunk":
await self.stop_ttfb_metrics()
self.start_word_timestamps()
frame = TTSAudioRawFrame(
audio=base64.b64decode(msg["data"]),
sample_rate=self._settings["output_format"]["sample_rate"],
num_channels=1,
)
await self.push_frame(frame)
elif msg["type"] == "error":
logger.error(f"{self} error: {msg}")
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
else:
logger.error(f"{self} error, unknown message type: {msg}")
async for attempt in AsyncRetrying(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
before_sleep=self._reconnect_websocket,
reraise=True,
):
with attempt:
await self._receive_messages()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"{self} exception: {e}")
await self._disconnect_websocket()
await self._connect_websocket()
message = f"{self} error receiving messages: {e}"
logger.error(message)
await self.push_error(ErrorFrame(message, fatal=True))
break

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
Expand Down
50 changes: 34 additions & 16 deletions src/pipecat/services/elevenlabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@

from loguru import logger
from pydantic import BaseModel, model_validator
from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential

from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
LLMFullResponseEndFrame,
StartFrame,
Expand Down Expand Up @@ -348,28 +350,44 @@ async def _disconnect_websocket(self):
except Exception as e:
logger.error(f"{self} error closing websocket: {e}")

async def _receive_messages(self):
async for message in self._websocket:
msg = json.loads(message)
if msg.get("audio"):
await self.stop_ttfb_metrics()
self.start_word_timestamps()

audio = base64.b64decode(msg["audio"])
frame = TTSAudioRawFrame(audio, self._settings["sample_rate"], 1)
await self.push_frame(frame)
if msg.get("alignment"):
word_times = calculate_word_times(msg["alignment"], self._cumulative_time)
await self.add_word_timestamps(word_times)
self._cumulative_time = word_times[-1][1]

async def _reconnect_websocket(self, retry_state: RetryCallState):
logger.warning(f"{self} reconnecting (attempt: {retry_state.attempt_number})")
await self._disconnect_websocket()
await self._connect_websocket()

async def _receive_task_handler(self):
while True:
try:
async for message in self._websocket:
msg = json.loads(message)
if msg.get("audio"):
await self.stop_ttfb_metrics()
self.start_word_timestamps()

audio = base64.b64decode(msg["audio"])
frame = TTSAudioRawFrame(audio, self._settings["sample_rate"], 1)
await self.push_frame(frame)
if msg.get("alignment"):
word_times = calculate_word_times(msg["alignment"], self._cumulative_time)
await self.add_word_timestamps(word_times)
self._cumulative_time = word_times[-1][1]
async for attempt in AsyncRetrying(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
before_sleep=self._reconnect_websocket,
reraise=True,
):
with attempt:
await self._receive_messages()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"{self} exception: {e}")
await self._disconnect_websocket()
await self._connect_websocket()
message = f"{self} error receiving messages: {e}"
logger.error(message)
await self.push_error(ErrorFrame(message, fatal=True))
break

async def _keepalive_task_handler(self):
while True:
Expand Down
55 changes: 36 additions & 19 deletions src/pipecat/services/lmnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import AsyncGenerator

from loguru import logger
from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential

from pipecat.frames.frames import (
CancelFrame,
Expand Down Expand Up @@ -159,31 +160,47 @@ async def _disconnect_lmnt(self):
except Exception as e:
logger.error(f"{self} error closing connection: {e}")

async def _receive_messages(self):
async for msg in self._connection:
if "error" in msg:
logger.error(f'{self} error: {msg["error"]}')
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
elif "audio" in msg:
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(
audio=msg["audio"],
sample_rate=self._settings["output_format"]["sample_rate"],
num_channels=1,
)
await self.push_frame(frame)
else:
logger.error(f"{self}: LMNT error, unknown message type: {msg}")

async def _reconnect_websocket(self, retry_state: RetryCallState):
logger.warning(f"{self} reconnecting (attempt: {retry_state.attempt_number})")
await self._disconnect_lmnt()
await self._connect_lmnt()

async def _receive_task_handler(self):
while True:
try:
async for msg in self._connection:
if "error" in msg:
logger.error(f'{self} error: {msg["error"]}')
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
elif "audio" in msg:
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(
audio=msg["audio"],
sample_rate=self._settings["output_format"]["sample_rate"],
num_channels=1,
)
await self.push_frame(frame)
else:
logger.error(f"{self}: LMNT error, unknown message type: {msg}")
async for attempt in AsyncRetrying(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
before_sleep=self._reconnect_websocket,
reraise=True,
):
with attempt:
await self._receive_messages()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"{self} exception: {e}")
await self._disconnect_lmnt()
await self._connect_lmnt()
message = f"{self} error receiving messages: {e}"
logger.error(message)
await self.push_error(ErrorFrame(message, fatal=True))
break

async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating TTS: [{text}]")
Expand Down
63 changes: 40 additions & 23 deletions src/pipecat/services/playht.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import websockets
from loguru import logger
from pydantic import BaseModel
from tenacity import AsyncRetrying, RetryCallState, stop_after_attempt, wait_exponential

from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
Expand Down Expand Up @@ -217,35 +218,51 @@ async def _handle_interruption(self, frame: StartInterruptionFrame, direction: F
await self.stop_all_metrics()
self._request_id = None

async def _receive_messages(self):
async for message in self._get_websocket():
if isinstance(message, bytes):
# Skip the WAV header message
if message.startswith(b"RIFF"):
continue
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(message, self._settings["sample_rate"], 1)
await self.push_frame(frame)
else:
logger.debug(f"Received text message: {message}")
try:
msg = json.loads(message)
if "request_id" in msg and msg["request_id"] == self._request_id:
await self.push_frame(TTSStoppedFrame())
self._request_id = None
elif "error" in msg:
logger.error(f"{self} error: {msg}")
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
except json.JSONDecodeError:
logger.error(f"Invalid JSON message: {message}")

async def _reconnect_websocket(self, retry_state: RetryCallState):
logger.warning(f"{self} reconnecting (attempt: {retry_state.attempt_number})")
await self._disconnect_websocket()
await self._connect_websocket()

async def _receive_task_handler(self):
while True:
try:
async for message in self._get_websocket():
if isinstance(message, bytes):
# Skip the WAV header message
if message.startswith(b"RIFF"):
continue
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(message, self._settings["sample_rate"], 1)
await self.push_frame(frame)
else:
logger.debug(f"Received text message: {message}")
try:
msg = json.loads(message)
if "request_id" in msg and msg["request_id"] == self._request_id:
await self.push_frame(TTSStoppedFrame())
self._request_id = None
elif "error" in msg:
logger.error(f"{self} error: {msg}")
await self.push_error(ErrorFrame(f'{self} error: {msg["error"]}'))
except json.JSONDecodeError:
logger.error(f"Invalid JSON message: {message}")
async for attempt in AsyncRetrying(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
before_sleep=self._reconnect_websocket,
reraise=True,
):
with attempt:
await self._receive_messages()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"{self} exception in receive task: {e}")
await self._disconnect_websocket()
await self._connect_websocket()
message = f"{self} error receiving messages: {e}"
logger.error(message)
await self.push_error(ErrorFrame(message, fatal=True))
break

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
Expand Down

0 comments on commit 653fbb7

Please sign in to comment.