From cc3874fcafae6dde306ee484282433fd7a3cbbae Mon Sep 17 00:00:00 2001
From: vipyne
Date: Wed, 4 Dec 2024 12:27:18 -0600
Subject: [PATCH] updates to be squashed

---
Review notes (this area is ignored by `git am` / `git apply`):

* Line structure and indentation were reconstructed from a
  whitespace-collapsed copy of this patch. Hunk line counts match the
  hunk headers and the diffstat, but the placement of blank lines in the
  first fastpitch.py hunk is inferred -- verify against the tree before
  applying.
* fastpitch.py: `sample_rate=sample_rate_hz` relies on a name whose
  definition is outside the hunk context -- confirm `__init__` defines it.
  The `# borked` comment does not say what is broken in `_connect`.
* parakeet.py: `LiveOptions`, `ParakeetClient`, `ParakeetClientOptions`,
  `AsyncListenWebSocketClient`, `LiveTranscriptionEvents` and
  `LiveResultResponse` are used but neither imported nor defined in the
  new file, so constructing `ParakeetSTTService` raises NameError (the
  `live_options: LiveOptions` parameter annotation fails even earlier,
  at import time, on interpreters that evaluate annotations eagerly).
* parakeet.py: `asyncio`, `ErrorFrame`, the `TTS*` frames, `TTSService`,
  `audio_io` and both `riva` imports are unused by the code in this patch.
* parakeet.py: `set_language` does not call `super().set_language(...)`
  while `set_model` does call `super().set_model(...)`; the default
  `model="nova-2-general"` appears to be carried over from another
  provider -- confirm.

 src/pipecat/services/fastpitch.py |   5 +-
 src/pipecat/services/parakeet.py  | 159 ++++++++++++++++++++++++++++++
 2 files changed, 161 insertions(+), 3 deletions(-)
 create mode 100644 src/pipecat/services/parakeet.py

diff --git a/src/pipecat/services/fastpitch.py b/src/pipecat/services/fastpitch.py
index 1207ebb83..69c3a02c2 100644
--- a/src/pipecat/services/fastpitch.py
+++ b/src/pipecat/services/fastpitch.py
@@ -35,9 +35,7 @@
 
 try:
     import websockets
-
     import riva.client
-    from riva.client.argparse_utils import add_connection_argparse_parameters
 except ModuleNotFoundError as e:
     logger.error(f"Exception: {e}")
     logger.error(
@@ -92,7 +90,7 @@ def __init__(
         super().__init__(
             aggregate_sentences=True,
             push_text_frames=False,
-            sample_rate=sample_rate,
+            sample_rate=sample_rate_hz,
             **kwargs,
         )
 
@@ -144,6 +142,7 @@ async def cancel(self, frame: CancelFrame):
 
     async def _connect(self):
         try:
+            # borked
             self._websocket = await websockets.connect(
                 f"{self._url}?api_key={self._api_key}&fastpitch_version={self._fastpitch_version}"
             )
diff --git a/src/pipecat/services/parakeet.py b/src/pipecat/services/parakeet.py
new file mode 100644
index 000000000..540f5bc4b
--- /dev/null
+++ b/src/pipecat/services/parakeet.py
@@ -0,0 +1,159 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import asyncio
+from typing import AsyncGenerator
+
+from loguru import logger
+
+from pipecat.frames.frames import (
+    CancelFrame,
+    EndFrame,
+    ErrorFrame,
+    Frame,
+    InterimTranscriptionFrame,
+    StartFrame,
+    TranscriptionFrame,
+    TTSAudioRawFrame,
+    TTSStartedFrame,
+    TTSStoppedFrame,
+)
+from pipecat.services.ai_services import STTService, TTSService
+from pipecat.transcriptions.language import Language
+from pipecat.utils.time import time_now_iso8601
+
+from pipecat.audio import audio_io
+
+try:
+
+    import riva.client
+    from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters
+
+except ModuleNotFoundError as e:
+    logger.error(f"Exception: {e}")
+    logger.error(
+        "In order to use Parakeet, you need to `pip install pipecat-ai[parakeet]`. Also, set `NVIDIA_API_KEY` environment variable."
+    )
+    raise Exception(f"Missing module: {e}")
+
+# TODO: maybe this becomes nvidia.py or nvidia-riva.py and we put the models in here ?
+# idk, breaks the established pattern in some ways, but keeps it in others...
+# TODO: finish this if that is what we want to do
+#### class FastpitchTTSService(TTSService):
+
+
+class ParakeetSTTService(STTService):
+    def __init__(
+        self,
+        *,
+        api_key: str,
+        url: str = "",
+        live_options: LiveOptions = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        default_options = LiveOptions(
+            encoding="linear16",
+            language=Language.EN,
+            model="nova-2-general",
+            sample_rate=16000,
+            channels=1,
+            interim_results=True,
+            smart_format=True,
+            punctuate=True,
+            profanity_filter=True,
+            vad_events=False,
+        )
+
+        merged_options = default_options
+        if live_options:
+            merged_options = LiveOptions(**{**default_options.to_dict(), **live_options.to_dict()})
+        self._settings = merged_options.to_dict()
+
+        self._client = ParakeetClient(
+            api_key,
+            config=ParakeetClientOptions(
+                url=url,
+                options={"keepalive": "true"},  # verbose=logging.DEBUG
+            ),
+        )
+        self._connection: AsyncListenWebSocketClient = self._client.listen.asyncwebsocket.v("1")
+        self._connection.on(LiveTranscriptionEvents.Transcript, self._on_message)
+        if self.vad_enabled:
+            self._connection.on(LiveTranscriptionEvents.SpeechStarted, self._on_speech_started)
+
+    @property
+    def vad_enabled(self):
+        return self._settings["vad_events"]
+
+    def can_generate_metrics(self) -> bool:
+        return self.vad_enabled
+
+    async def set_model(self, model: str):
+        await super().set_model(model)
+        logger.info(f"Switching STT model to: [{model}]")
+        self._settings["model"] = model
+        await self._disconnect()
+        await self._connect()
+
+    async def set_language(self, language: Language):
+        logger.info(f"Switching STT language to: [{language}]")
+        self._settings["language"] = language
+        await self._disconnect()
+        await self._connect()
+
+    async def start(self, frame: StartFrame):
+        await super().start(frame)
+        await self._connect()
+
+    async def stop(self, frame: EndFrame):
+        await super().stop(frame)
+        await self._disconnect()
+
+    async def cancel(self, frame: CancelFrame):
+        await super().cancel(frame)
+        await self._disconnect()
+
+    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
+        await self._connection.send(audio)
+        yield None
+
+    async def _connect(self):
+        if await self._connection.start(self._settings):
+            logger.info(f"{self}: Connected to Parakeet")
+        else:
+            logger.error(f"{self}: Unable to connect to Parakeet")
+
+    async def _disconnect(self):
+        if self._connection.is_connected:
+            await self._connection.finish()
+            logger.info(f"{self}: Disconnected from Parakeet")
+
+    async def _on_speech_started(self, *args, **kwargs):
+        await self.start_ttfb_metrics()
+        await self.start_processing_metrics()
+
+    async def _on_message(self, *args, **kwargs):
+        result: LiveResultResponse = kwargs["result"]
+        if len(result.channel.alternatives) == 0:
+            return
+        is_final = result.is_final
+        transcript = result.channel.alternatives[0].transcript
+        language = None
+        if result.channel.alternatives[0].languages:
+            language = result.channel.alternatives[0].languages[0]
+            language = Language(language)
+        if len(transcript) > 0:
+            await self.stop_ttfb_metrics()
+            if is_final:
+                await self.push_frame(
+                    TranscriptionFrame(transcript, "", time_now_iso8601(), language)
+                )
+                await self.stop_processing_metrics()
+            else:
+                await self.push_frame(
+                    InterimTranscriptionFrame(transcript, "", time_now_iso8601(), language)
+                )