diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6adf6a994..eaccd2d26 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,13 +16,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added `enable_prejoin_ui`, `max_participants` and `start_video_off` params to
   `DailyRoomProperties`.
 
 - Added `session_timeout` to `FastAPIWebsocketTransport` and `WebsocketServerTransport`
-  for configuring session timeouts (in seconds). Triggers `on_session_timeout` for custom timeout handling.
+  for configuring session timeouts (in seconds). Triggers `on_session_timeout` for custom timeout handling. See [examples/websocket-server/bot.py](https://github.com/pipecat-ai/pipecat/blob/main/examples/websocket-server/bot.py).
+- Added a `modalities` option to `InputParams` and a `set_model_modalities()` helper to set the Gemini Multimodal Live output modality.
+- Added `examples/foundational/26d-gemini-multimodal-live-text.py`, which runs Gemini in TEXT modality and uses a separate TTS service to speak the responses.
 
 ### Changed
 
 - api_key, aws_access_key_id and region are no longer required parameters for the PollyTTSService (AWSTTSService)
 
 - Added `session_timeout` example in `examples/websocket-server/bot.py` to handle session timeout event.
+- Changed `InputParams` in `src/pipecat/services/gemini_multimodal_live/gemini.py` to support different modalities.
 
 ### Fixed
diff --git a/examples/foundational/26d-gemini-multimodal-live-text.py b/examples/foundational/26d-gemini-multimodal-live-text.py
new file mode 100644
index 000000000..760af39ce
--- /dev/null
+++ b/examples/foundational/26d-gemini-multimodal-live-text.py
@@ -0,0 +1,85 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import asyncio
+import os
+import sys
+
+import aiohttp
+from dotenv import load_dotenv
+from loguru import logger
+from runner import configure
+
+from pipecat.audio.vad.silero import SileroVADAnalyzer
+from pipecat.audio.vad.vad_analyzer import VADParams
+from pipecat.pipeline.pipeline import Pipeline
+from pipecat.pipeline.runner import PipelineRunner
+from pipecat.pipeline.task import PipelineParams, PipelineTask
+from pipecat.services.cartesia import CartesiaTTSService
+from pipecat.services.gemini_multimodal_live.gemini import (
+    GeminiMultimodalLiveLLMService,
+    GeminiMultimodalModalities,
+)
+from pipecat.transports.services.daily import DailyParams, DailyTransport
+
+load_dotenv(override=True)
+
+logger.remove(0)
+logger.add(sys.stderr, level="DEBUG")
+
+
+async def main():
+    async with aiohttp.ClientSession() as session:
+        (room_url, token) = await configure(session)
+
+        transport = DailyTransport(
+            room_url,
+            token,
+            "Respond bot",
+            DailyParams(
+                audio_in_sample_rate=16000,
+                audio_out_sample_rate=24000,
+                audio_out_enabled=True,
+                vad_enabled=True,
+                vad_audio_passthrough=True,
+                # set stop_secs to something roughly similar to the internal setting
+                # of the Multimodal Live api, just to align events. This doesn't really
+                # matter because we can only use the Multimodal Live API's phrase
+                # endpointing, for now.
+                vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.5)),
+            ),
+        )
+
+        llm = GeminiMultimodalLiveLLMService(
+            api_key=os.getenv("GOOGLE_API_KEY"),
+            # system_instruction="Talk like a pirate."
+        )
+        # Force the model to produce text-only responses; the TTS service below
+        # converts them to audio.
+        llm.set_model_modalities(GeminiMultimodalModalities.TEXT)
+
+        tts = CartesiaTTSService(
+            api_key=os.getenv("CARTESIA_API_KEY"),
+            voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
+        )
+
+        pipeline = Pipeline(
+            [
+                transport.input(),
+                llm,
+                tts,
+                transport.output(),
+            ]
+        )
+
+        task = PipelineTask(
+            pipeline,
+            PipelineParams(
+                allow_interruptions=True,
+                enable_metrics=True,
+                enable_usage_metrics=True,
+            ),
+        )
+
+        runner = PipelineRunner()
+
+        await runner.run(task)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/src/pipecat/services/gemini_multimodal_live/events.py b/src/pipecat/services/gemini_multimodal_live/events.py
index 0d5bc802f..36541aa30 100644
--- a/src/pipecat/services/gemini_multimodal_live/events.py
+++ b/src/pipecat/services/gemini_multimodal_live/events.py
@@ -105,6 +105,7 @@ class InlineData(BaseModel):
 
 class Part(BaseModel):
     inlineData: Optional[InlineData] = None
+    text: Optional[str] = None
 
 
 class ModelTurn(BaseModel):
diff --git a/src/pipecat/services/gemini_multimodal_live/gemini.py b/src/pipecat/services/gemini_multimodal_live/gemini.py
index dd4375486..1d76f191c 100644
--- a/src/pipecat/services/gemini_multimodal_live/gemini.py
+++ b/src/pipecat/services/gemini_multimodal_live/gemini.py
@@ -8,6 +8,7 @@
 import base64
 import json
 from dataclasses import dataclass
+from enum import Enum
 from typing import Any, Dict, List, Optional
 
 import websockets
@@ -132,6 +133,11 @@ def assistant(self) -> GeminiMultimodalLiveAssistantContextAggregator:
         return self._assistant
 
 
+class GeminiMultimodalModalities(Enum):
+    TEXT = "TEXT"
+    AUDIO = "AUDIO"
+
+
 class InputParams(BaseModel):
     frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
     max_tokens: Optional[int] = Field(default=4096, ge=1)
@@ -139,6 +145,9 @@ class InputParams(BaseModel):
     temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
     top_k: Optional[int] = Field(default=None, ge=0)
     top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
+    modalities: Optional[GeminiMultimodalModalities] = Field(
+        default=GeminiMultimodalModalities.AUDIO
+    )
     extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
 
 
@@ -188,6 +197,7 @@ def __init__(
         self._bot_is_speaking = False
         self._user_audio_buffer = bytearray()
         self._bot_audio_buffer = bytearray()
+        self._bot_text_buffer = ""
 
         self._settings = {
             "frequency_penalty": params.frequency_penalty,
@@ -196,6 +206,7 @@ def __init__(
             "temperature": params.temperature,
             "top_k": params.top_k,
             "top_p": params.top_p,
+            "modalities": params.modalities,
             "extra": params.extra if isinstance(params.extra, dict) else {},
         }
 
@@ -208,6 +219,9 @@ def set_audio_input_paused(self, paused: bool):
     def set_video_input_paused(self, paused: bool):
         self._video_input_paused = paused
 
+    def set_model_modalities(self, modalities: GeminiMultimodalModalities):
+        self._settings["modalities"] = modalities
+
     async def set_context(self, context: OpenAILLMContext):
         """Set the context explicitly from outside the pipeline.
 
@@ -383,7 +397,7 @@ async def _connect(self):
                         "temperature": self._settings["temperature"],
                         "top_k": self._settings["top_k"],
                         "top_p": self._settings["top_p"],
-                        "response_modalities": ["AUDIO"],
+                        "response_modalities": self._settings["modalities"].value,
                         "speech_config": {
                             "voice_config": {
                                 "prebuilt_voice_config": {"voice_name": self._voice_id}
@@ -604,6 +618,15 @@ async def _handle_evt_model_turn(self, evt):
         part = evt.serverContent.modelTurn.parts[0]
         if not part:
             return
+
+        text = part.text
+        if text:
+            if not self._bot_text_buffer:
+                await self.push_frame(LLMFullResponseStartFrame())
+
+            self._bot_text_buffer += text
+            await self.push_frame(TextFrame(text=text))
+
         inline_data = part.inlineData
         if not inline_data:
             return
@@ -644,9 +667,15 @@ async def _handle_evt_tool_call(self, evt):
     async def _handle_evt_turn_complete(self, evt):
         self._bot_is_speaking = False
         audio = self._bot_audio_buffer
+        text = self._bot_text_buffer
         self._bot_audio_buffer = bytearray()
+        self._bot_text_buffer = ""
+
         if audio and self._transcribe_model_audio and self._context:
             asyncio.create_task(self._handle_transcribe_model_audio(audio, self._context))
+        elif text:
+            await self.push_frame(LLMFullResponseEndFrame())
+
         await self.push_frame(TTSStoppedFrame())
 
     def create_context_aggregator(
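
Minimal usage sketch (not part of the patch), assuming the `InputParams`, `GeminiMultimodalModalities`, and `set_model_modalities()` definitions above: the output modality can be selected up front through `InputParams` or switched later with the helper. Since `_connect()` reads the setting when it builds the setup message, the modality must be set before the service connects; the API-key handling below is illustrative.

import os

from pipecat.services.gemini_multimodal_live.gemini import (
    GeminiMultimodalLiveLLMService,
    GeminiMultimodalModalities,
    InputParams,
)

# Request text-only responses at construction time via InputParams, which the
# service's __init__ copies into its settings.
llm = GeminiMultimodalLiveLLMService(
    api_key=os.getenv("GOOGLE_API_KEY"),
    params=InputParams(modalities=GeminiMultimodalModalities.TEXT),
)

# ...or change the modality with the helper, before the service connects.
llm.set_model_modalities(GeminiMultimodalModalities.AUDIO)

In TEXT modality the service pushes TextFrames (wrapped in LLMFullResponseStartFrame/LLMFullResponseEndFrame), so a downstream TTS service can generate the audio, as in the 26d example above.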