Skip to content

Commit

Permalink
Merge pull request #333 from pipecat-ai/aleix/allow-internal-http-ses…
Browse files Browse the repository at this point in the history
…sions-rebased

services: allow internal http sessions if none is given
  • Loading branch information
aconchillo authored Aug 1, 2024
2 parents 62a7a55 + 3bfeb5b commit 3db7f6a
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 13 deletions.
10 changes: 8 additions & 2 deletions src/pipecat/services/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,21 +171,27 @@ class AzureImageGenServiceREST(ImageGenService):
def __init__(
self,
*,
aiohttp_session: aiohttp.ClientSession,
image_size: str,
api_key: str,
endpoint: str,
model: str,
api_version="2023-06-01-preview",
aiohttp_session: aiohttp.ClientSession | None = None,
):
super().__init__()

self._api_key = api_key
self._azure_endpoint = endpoint
self._api_version = api_version
self._model = model
self._aiohttp_session = aiohttp_session
self._image_size = image_size
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
self._close_aiohttp_session = aiohttp_session is None

async def cleanup(self):
await super().cleanup()
if self._close_aiohttp_session:
await self._aiohttp_session.close()

async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
url = f"{self._azure_endpoint}openai/images/generations:submit?api-version={self._api_version}"
Expand Down
10 changes: 8 additions & 2 deletions src/pipecat/services/deepgram.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,25 +45,31 @@ class DeepgramTTSService(TTSService):
def __init__(
self,
*,
aiohttp_session: aiohttp.ClientSession,
api_key: str,
voice: str = "aura-helios-en",
base_url: str = "https://api.deepgram.com/v1/speak",
sample_rate: int = 16000,
encoding: str = "linear16",
aiohttp_session: aiohttp.ClientSession | None = None,
**kwargs):
super().__init__(**kwargs)

self._voice = voice
self._api_key = api_key
self._aiohttp_session = aiohttp_session
self._base_url = base_url
self._sample_rate = sample_rate
self._encoding = encoding
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
self._close_aiohttp_session = aiohttp_session is None

def can_generate_metrics(self) -> bool:
return True

async def cleanup(self):
await super().cleanup()
if self._close_aiohttp_session:
await self._aiohttp_session.close()

async def set_voice(self, voice: str):
logger.debug(f"Switching TTS voice to: [{voice}]")
self._voice = voice
Expand Down
10 changes: 8 additions & 2 deletions src/pipecat/services/elevenlabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,27 @@ class ElevenLabsTTSService(TTSService):
def __init__(
self,
*,
aiohttp_session: aiohttp.ClientSession,
api_key: str,
voice_id: str,
model: str = "eleven_turbo_v2",
aiohttp_session: aiohttp.ClientSession | None = None,
**kwargs):
super().__init__(**kwargs)

self._api_key = api_key
self._voice_id = voice_id
self._aiohttp_session = aiohttp_session
self._model = model
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
self._close_aiohttp_session = aiohttp_session is None

def can_generate_metrics(self) -> bool:
return True

async def cleanup(self):
await super().cleanup()
if self._close_aiohttp_session:
await self._aiohttp_session.close()

async def set_voice(self, voice: str):
logger.debug(f"Switching TTS voice to: [{voice}]")
self._voice_id = voice
Expand Down
10 changes: 8 additions & 2 deletions src/pipecat/services/fal.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,24 @@ class InputParams(BaseModel):
def __init__(
self,
*,
aiohttp_session: aiohttp.ClientSession,
params: InputParams,
model: str = "fal-ai/fast-sdxl",
key: str | None = None,
aiohttp_session: aiohttp.ClientSession | None = None,
):
super().__init__()
self._model = model
self._params = params
self._aiohttp_session = aiohttp_session
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
self._close_aiohttp_session = aiohttp_session is None
if key:
os.environ["FAL_KEY"] = key

async def cleanup(self):
await super().cleanup()
if self._close_aiohttp_session:
await self._aiohttp_session.close()

async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating image from prompt: {prompt}")

Expand Down
12 changes: 9 additions & 3 deletions src/pipecat/services/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,16 +253,22 @@ class OpenAIImageGenService(ImageGenService):
def __init__(
self,
*,
image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
aiohttp_session: aiohttp.ClientSession,
api_key: str,
model: str = "dall-e-3",
image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
aiohttp_session: aiohttp.ClientSession | None = None,
):
super().__init__()
self._model = model
self._image_size = image_size
self._client = AsyncOpenAI(api_key=api_key)
self._aiohttp_session = aiohttp_session
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
self._close_aiohttp_session = aiohttp_session is None

async def cleanup(self):
await super().cleanup()
if self._close_aiohttp_session:
await self._aiohttp_session.close()

async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating image from prompt: {prompt}")
Expand Down
10 changes: 8 additions & 2 deletions src/pipecat/services/xtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,28 @@ class XTTSService(TTSService):
def __init__(
self,
*,
aiohttp_session: aiohttp.ClientSession,
voice_id: str,
language: str,
base_url: str,
aiohttp_session: aiohttp.ClientSession | None = None,
**kwargs):
super().__init__(**kwargs)

self._voice_id = voice_id
self._language = language
self._base_url = base_url
self._aiohttp_session = aiohttp_session
self._studio_speakers = requests.get(self._base_url + "/studio_speakers").json()
self._aiohttp_session = aiohttp_session or aiohttp.ClientSession()
self._close_aiohttp_session = aiohttp_session is None

def can_generate_metrics(self) -> bool:
return True

async def cleanup(self):
await super().cleanup()
if self._close_aiohttp_session:
await self._aiohttp_session.close()

async def set_voice(self, voice: str):
logger.debug(f"Switching TTS voice to: [{voice}]")
self._voice_id = voice
Expand Down

0 comments on commit 3db7f6a

Please sign in to comment.