Skip to content

Commit

Permalink
updates to be squashed
Browse files Browse the repository at this point in the history
  • Loading branch information
vipyne committed Dec 4, 2024
1 parent 7215ca2 commit cc3874f
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 3 deletions.
5 changes: 2 additions & 3 deletions src/pipecat/services/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@

try:
import websockets

import riva.client
from riva.client.argparse_utils import add_connection_argparse_parameters
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
Expand Down Expand Up @@ -92,7 +90,7 @@ def __init__(
super().__init__(
aggregate_sentences=True,
push_text_frames=False,
sample_rate=sample_rate,
sample_rate=sample_rate_hz,
**kwargs,
)

Expand Down Expand Up @@ -144,6 +142,7 @@ async def cancel(self, frame: CancelFrame):

async def _connect(self):
try:
# borked
self._websocket = await websockets.connect(
f"{self._url}?api_key={self._api_key}&fastpitch_version={self._fastpitch_version}"
)
Expand Down
159 changes: 159 additions & 0 deletions src/pipecat/services/parakeet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
from typing import AsyncGenerator

from loguru import logger

from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
TranscriptionFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.ai_services import STTService, TTSService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601

from pipecat.audio import audio_io

try:

import riva.client
from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters

except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use Parakeet, you need to `pip install pipecat-ai[parakeet]`. Also, set `NVIDIA_API_KEY` environment variable."
)
raise Exception(f"Missing module: {e}")

# TODO: maybe this becomes nvidia.py or nvidia-riva.py and we put the models in here ?
# idk, breaks the established pattern in some ways, but keeps it in others...
# TODO: finish this if that is what we want to do
#### class FastpitchTTSService(TTSService):


class ParakeetSTTService(STTService):
def __init__(
self,
*,
api_key: str,
url: str = "",
live_options: LiveOptions = None,
**kwargs,
):
super().__init__(**kwargs)
default_options = LiveOptions(
encoding="linear16",
language=Language.EN,
model="nova-2-general",
sample_rate=16000,
channels=1,
interim_results=True,
smart_format=True,
punctuate=True,
profanity_filter=True,
vad_events=False,
)

merged_options = default_options
if live_options:
merged_options = LiveOptions(**{**default_options.to_dict(), **live_options.to_dict()})
self._settings = merged_options.to_dict()

self._client = ParakeetClient(
api_key,
config=ParakeetClientOptions(
url=url,
options={"keepalive": "true"}, # verbose=logging.DEBUG
),
)
self._connection: AsyncListenWebSocketClient = self._client.listen.asyncwebsocket.v("1")
self._connection.on(LiveTranscriptionEvents.Transcript, self._on_message)
if self.vad_enabled:
self._connection.on(LiveTranscriptionEvents.SpeechStarted, self._on_speech_started)

@property
def vad_enabled(self):
return self._settings["vad_events"]

def can_generate_metrics(self) -> bool:
return self.vad_enabled

async def set_model(self, model: str):
await super().set_model(model)
logger.info(f"Switching STT model to: [{model}]")
self._settings["model"] = model
await self._disconnect()
await self._connect()

async def set_language(self, language: Language):
logger.info(f"Switching STT language to: [{language}]")
self._settings["language"] = language
await self._disconnect()
await self._connect()

async def start(self, frame: StartFrame):
await super().start(frame)
await self._connect()

async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._disconnect()

async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._disconnect()

async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
await self._connection.send(audio)
yield None

async def _connect(self):
if await self._connection.start(self._settings):
logger.info(f"{self}: Connected to Parakeet")
else:
logger.error(f"{self}: Unable to connect to Parakeet")

async def _disconnect(self):
if self._connection.is_connected:
await self._connection.finish()
logger.info(f"{self}: Disconnected from Parakeet")

async def _on_speech_started(self, *args, **kwargs):
await self.start_ttfb_metrics()
await self.start_processing_metrics()

async def _on_message(self, *args, **kwargs):
result: LiveResultResponse = kwargs["result"]
if len(result.channel.alternatives) == 0:
return
is_final = result.is_final
transcript = result.channel.alternatives[0].transcript
language = None
if result.channel.alternatives[0].languages:
language = result.channel.alternatives[0].languages[0]
language = Language(language)
if len(transcript) > 0:
await self.stop_ttfb_metrics()
if is_final:
await self.push_frame(
TranscriptionFrame(transcript, "", time_now_iso8601(), language)
)
await self.stop_processing_metrics()
else:
await self.push_frame(
InterimTranscriptionFrame(transcript, "", time_now_iso8601(), language)
)

0 comments on commit cc3874f

Please sign in to comment.