diff --git a/src/pipecat/services/azure.py b/src/pipecat/services/azure.py index 2d43d9a8c..17ec34051 100644 --- a/src/pipecat/services/azure.py +++ b/src/pipecat/services/azure.py @@ -171,12 +171,12 @@ class AzureImageGenServiceREST(ImageGenService): def __init__( self, *, - aiohttp_session: aiohttp.ClientSession, image_size: str, api_key: str, endpoint: str, model: str, api_version="2023-06-01-preview", + aiohttp_session: aiohttp.ClientSession | None = None, ): super().__init__() @@ -184,8 +184,14 @@ def __init__( self._azure_endpoint = endpoint self._api_version = api_version self._model = model - self._aiohttp_session = aiohttp_session self._image_size = image_size + self._aiohttp_session = aiohttp_session or aiohttp.ClientSession() + self._close_aiohttp_session = aiohttp_session is None + + async def cleanup(self): + await super().cleanup() + if self._close_aiohttp_session: + await self._aiohttp_session.close() async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]: url = f"{self._azure_endpoint}openai/images/generations:submit?api-version={self._api_version}" diff --git a/src/pipecat/services/deepgram.py b/src/pipecat/services/deepgram.py index 1c20df45c..1cc69ca5a 100644 --- a/src/pipecat/services/deepgram.py +++ b/src/pipecat/services/deepgram.py @@ -45,25 +45,31 @@ class DeepgramTTSService(TTSService): def __init__( self, *, - aiohttp_session: aiohttp.ClientSession, api_key: str, voice: str = "aura-helios-en", base_url: str = "https://api.deepgram.com/v1/speak", sample_rate: int = 16000, encoding: str = "linear16", + aiohttp_session: aiohttp.ClientSession | None = None, **kwargs): super().__init__(**kwargs) self._voice = voice self._api_key = api_key - self._aiohttp_session = aiohttp_session self._base_url = base_url self._sample_rate = sample_rate self._encoding = encoding + self._aiohttp_session = aiohttp_session or aiohttp.ClientSession() + self._close_aiohttp_session = aiohttp_session is None def can_generate_metrics(self) -> bool: return True + async def cleanup(self): + await super().cleanup() + if self._close_aiohttp_session: + await self._aiohttp_session.close() + async def set_voice(self, voice: str): logger.debug(f"Switching TTS voice to: [{voice}]") self._voice = voice diff --git a/src/pipecat/services/elevenlabs.py b/src/pipecat/services/elevenlabs.py index 1bf0fe6ca..b81773346 100644 --- a/src/pipecat/services/elevenlabs.py +++ b/src/pipecat/services/elevenlabs.py @@ -19,21 +19,27 @@ class ElevenLabsTTSService(TTSService): def __init__( self, *, - aiohttp_session: aiohttp.ClientSession, api_key: str, voice_id: str, model: str = "eleven_turbo_v2", + aiohttp_session: aiohttp.ClientSession | None = None, **kwargs): super().__init__(**kwargs) self._api_key = api_key self._voice_id = voice_id - self._aiohttp_session = aiohttp_session self._model = model + self._aiohttp_session = aiohttp_session or aiohttp.ClientSession() + self._close_aiohttp_session = aiohttp_session is None def can_generate_metrics(self) -> bool: return True + async def cleanup(self): + await super().cleanup() + if self._close_aiohttp_session: + await self._aiohttp_session.close() + async def set_voice(self, voice: str): logger.debug(f"Switching TTS voice to: [{voice}]") self._voice_id = voice diff --git a/src/pipecat/services/fal.py b/src/pipecat/services/fal.py index cceddd6ff..f4811e4d5 100644 --- a/src/pipecat/services/fal.py +++ b/src/pipecat/services/fal.py @@ -39,18 +39,24 @@ class InputParams(BaseModel): def __init__( self, *, - aiohttp_session: aiohttp.ClientSession, params: InputParams, model: str = "fal-ai/fast-sdxl", key: str | None = None, + aiohttp_session: aiohttp.ClientSession | None = None, ): super().__init__() self._model = model self._params = params - self._aiohttp_session = aiohttp_session + self._aiohttp_session = aiohttp_session or aiohttp.ClientSession() + self._close_aiohttp_session = aiohttp_session is None if key: os.environ["FAL_KEY"] = key + async def cleanup(self): + await super().cleanup() + if self._close_aiohttp_session: + await self._aiohttp_session.close() + async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]: logger.debug(f"Generating image from prompt: {prompt}") diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index 3e1a6effc..574021273 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -253,16 +253,22 @@ class OpenAIImageGenService(ImageGenService): def __init__( self, *, - image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"], - aiohttp_session: aiohttp.ClientSession, api_key: str, model: str = "dall-e-3", + image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"], + aiohttp_session: aiohttp.ClientSession | None = None, ): super().__init__() self._model = model self._image_size = image_size self._client = AsyncOpenAI(api_key=api_key) - self._aiohttp_session = aiohttp_session + self._aiohttp_session = aiohttp_session or aiohttp.ClientSession() + self._close_aiohttp_session = aiohttp_session is None + + async def cleanup(self): + await super().cleanup() + if self._close_aiohttp_session: + await self._aiohttp_session.close() async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]: logger.debug(f"Generating image from prompt: {prompt}") diff --git a/src/pipecat/services/xtts.py b/src/pipecat/services/xtts.py index a17277b88..396bc2919 100644 --- a/src/pipecat/services/xtts.py +++ b/src/pipecat/services/xtts.py @@ -38,22 +38,28 @@ class XTTSService(TTSService): def __init__( self, *, - aiohttp_session: aiohttp.ClientSession, voice_id: str, language: str, base_url: str, + aiohttp_session: aiohttp.ClientSession | None = None, **kwargs): super().__init__(**kwargs) self._voice_id = voice_id self._language = language self._base_url = base_url - self._aiohttp_session = aiohttp_session self._studio_speakers = requests.get(self._base_url + "/studio_speakers").json() + self._aiohttp_session = aiohttp_session or aiohttp.ClientSession() + self._close_aiohttp_session = aiohttp_session is None def can_generate_metrics(self) -> bool: return True + async def cleanup(self): + await super().cleanup() + if self._close_aiohttp_session: + await self._aiohttp_session.close() + async def set_voice(self, voice: str): logger.debug(f"Switching TTS voice to: [{voice}]") self._voice_id = voice