Merge pull request #926 from imsakg/main
feat(gemini): add text handling to GeminiMultimodalLive
markbackman authored Jan 8, 2025
2 parents 8057fe3 + 40e9ee6 commit 9dae753
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 2 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -16,13 +16,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `enable_prejoin_ui`, `max_participants` and `start_video_off` params
to `DailyRoomProperties`.
- Added `session_timeout` to `FastAPIWebsocketTransport` and `WebsocketServerTransport`
  for configuring session timeouts (in seconds). Triggers `on_session_timeout` for custom timeout handling.
See [examples/websocket-server/bot.py](https://github.com/pipecat-ai/pipecat/blob/main/examples/websocket-server/bot.py).
- Added a new `modalities` option and a helper function for setting Gemini output modalities.
- Added `examples/foundational/26d-gemini-multimodal-live-text.py`, which runs Gemini in TEXT modality and uses a separate TTS service for speech output.

### Changed

- `api_key`, `aws_access_key_id`, and `region` are no longer required parameters for the `PollyTTSService` (`AWSTTSService`).
- Added a `session_timeout` example in `examples/websocket-server/bot.py` showing how to handle the session timeout event.
- Changed `InputParams` in `src/pipecat/services/gemini_multimodal_live/gemini.py` to support different modalities.

### Fixed

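For reference, here is a minimal sketch of the new option in use, configuring text-only output at construction time. It assumes the service constructor accepts a `params` keyword (the `__init__` hunk further down reads `params.modalities`) and that `GOOGLE_API_KEY` is set in the environment:

```python
import os

from pipecat.services.gemini_multimodal_live.gemini import (
    GeminiMultimodalLiveLLMService,
    GeminiMultimodalModalities,
    InputParams,
)

# Request TEXT instead of the default AUDIO modality. A downstream TTS
# service (as in the new example file below) then handles speech synthesis.
llm = GeminiMultimodalLiveLLMService(
    api_key=os.getenv("GOOGLE_API_KEY"),
    params=InputParams(modalities=GeminiMultimodalModalities.TEXT),
)
```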
85 changes: 85 additions & 0 deletions examples/foundational/26d-gemini-multimodal-live-text.py
@@ -0,0 +1,85 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import os
import sys

import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure

from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.gemini_multimodal_live.gemini import (
    GeminiMultimodalLiveLLMService,
    GeminiMultimodalModalities,
)
from pipecat.transports.services.daily import DailyParams, DailyTransport

load_dotenv(override=True)

logger.remove(0)
logger.add(sys.stderr, level="DEBUG")


async def main():
    async with aiohttp.ClientSession() as session:
        (room_url, token) = await configure(session)

        transport = DailyTransport(
            room_url,
            token,
            "Respond bot",
            DailyParams(
                audio_in_sample_rate=16000,
                audio_out_sample_rate=24000,
                audio_out_enabled=True,
                vad_enabled=True,
                vad_audio_passthrough=True,
                # Set stop_secs to something roughly similar to the internal
                # setting of the Multimodal Live API, just to align events.
                # This doesn't really matter because we can only use the
                # Multimodal Live API's phrase endpointing, for now.
                vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.5)),
            ),
        )

        llm = GeminiMultimodalLiveLLMService(
            api_key=os.getenv("GOOGLE_API_KEY"),
            # system_instruction="Talk like a pirate."
        )
        # Force the model to produce text-only responses, using the
        # set_model_modalities() helper added in this change.
        llm.set_model_modalities(GeminiMultimodalModalities.TEXT)

        # NOTE: the original example imported a project-local
        # CartesiaMultiLingualTTSService that isn't part of pipecat; pipecat's
        # CartesiaTTSService is used here instead. The voice_id is a sample
        # value; substitute your own.
        tts = CartesiaTTSService(
            api_key=os.getenv("CARTESIA_API_KEY"),
            voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",
        )

        pipeline = Pipeline(
            [
                transport.input(),
                llm,
                tts,
                transport.output(),
            ]
        )

        task = PipelineTask(
            pipeline,
            PipelineParams(
                allow_interruptions=True,
                enable_metrics=True,
                enable_usage_metrics=True,
            ),
        )

        runner = PipelineRunner()

        await runner.run(task)


if __name__ == "__main__":
    asyncio.run(main())
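Since the LLM now emits `TextFrame`s for the TTS stage to consume, a small hypothetical debugging processor can make the handoff visible. This is a sketch, not part of the example; it assumes pipecat's standard `FrameProcessor` subclassing pattern and would sit between `llm` and `tts` in the `Pipeline` list above:

```python
from loguru import logger

from pipecat.frames.frames import Frame, TextFrame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor


class TextLogger(FrameProcessor):
    """Hypothetical helper: log each text chunk Gemini emits before TTS."""

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)
        if isinstance(frame, TextFrame):
            logger.debug(f"Gemini text chunk: {frame.text!r}")
        # Pass every frame along so the pipeline keeps flowing.
        await self.push_frame(frame, direction)
```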
1 change: 1 addition & 0 deletions src/pipecat/services/gemini_multimodal_live/events.py
@@ -105,6 +105,7 @@ class InlineData(BaseModel):


class Part(BaseModel):
    inlineData: Optional[InlineData] = None
    text: Optional[str] = None


class ModelTurn(BaseModel):
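To see what this one-line model change enables, here is a quick sketch (assuming Pydantic v2's `model_validate`) of parsing a text-bearing server part alongside the existing audio path:

```python
from pipecat.services.gemini_multimodal_live.events import Part

# A text-modality part, as sent when response_modalities is TEXT.
text_part = Part.model_validate({"text": "Hello there."})
assert text_part.text == "Hello there."
assert text_part.inlineData is None  # no audio payload in text mode
```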
31 changes: 30 additions & 1 deletion src/pipecat/services/gemini_multimodal_live/gemini.py
@@ -8,6 +8,7 @@
import base64
import json
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional

import websockets
@@ -132,13 +133,21 @@ def assistant(self) -> GeminiMultimodalLiveAssistantContextAggregator:
        return self._assistant


class GeminiMultimodalModalities(Enum):
    TEXT = "TEXT"
    AUDIO = "AUDIO"


class InputParams(BaseModel):
    frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
    max_tokens: Optional[int] = Field(default=4096, ge=1)
    presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
    top_k: Optional[int] = Field(default=None, ge=0)
    top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    modalities: Optional[GeminiMultimodalModalities] = Field(
        default=GeminiMultimodalModalities.AUDIO
    )
    extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
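A short sketch of the field's behavior: the modality defaults to `AUDIO`, and the enum's `.value` is the string later sent as `response_modalities` (see the `_connect` hunk below):

```python
from pipecat.services.gemini_multimodal_live.gemini import (
    GeminiMultimodalModalities,
    InputParams,
)

params = InputParams()
assert params.modalities == GeminiMultimodalModalities.AUDIO  # default

params = InputParams(modalities=GeminiMultimodalModalities.TEXT)
assert params.modalities.value == "TEXT"  # wire format for response_modalities
```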


@@ -188,6 +197,7 @@ def __init__(
        self._bot_is_speaking = False
        self._user_audio_buffer = bytearray()
        self._bot_audio_buffer = bytearray()
        self._bot_text_buffer = ""

        self._settings = {
            "frequency_penalty": params.frequency_penalty,
@@ -196,6 +206,7 @@ def __init__(
            "temperature": params.temperature,
            "top_k": params.top_k,
            "top_p": params.top_p,
            "modalities": params.modalities,
            "extra": params.extra if isinstance(params.extra, dict) else {},
        }

@@ -208,6 +219,9 @@ def set_audio_input_paused(self, paused: bool):
    def set_video_input_paused(self, paused: bool):
        self._video_input_paused = paused

    def set_model_modalities(self, modalities: GeminiMultimodalModalities):
        self._settings["modalities"] = modalities

    async def set_context(self, context: OpenAILLMContext):
        """Set the context explicitly from outside the pipeline.
@@ -383,7 +397,7 @@ async def _connect(self):
                        "temperature": self._settings["temperature"],
                        "top_k": self._settings["top_k"],
                        "top_p": self._settings["top_p"],
-                       "response_modalities": ["AUDIO"],
+                       "response_modalities": self._settings["modalities"].value,
                        "speech_config": {
                            "voice_config": {
                                "prebuilt_voice_config": {"voice_name": self._voice_id}
@@ -604,6 +618,15 @@ async def _handle_evt_model_turn(self, evt):
        part = evt.serverContent.modelTurn.parts[0]
        if not part:
            return

        text = part.text
        if text:
            if not self._bot_text_buffer:
                await self.push_frame(LLMFullResponseStartFrame())

            self._bot_text_buffer += text
            await self.push_frame(TextFrame(text=text))

        inline_data = part.inlineData
        if not inline_data:
            return
@@ -644,9 +667,15 @@ async def _handle_evt_tool_call(self, evt):
    async def _handle_evt_turn_complete(self, evt):
        self._bot_is_speaking = False
        audio = self._bot_audio_buffer
        text = self._bot_text_buffer
        self._bot_audio_buffer = bytearray()
        self._bot_text_buffer = ""

        if audio and self._transcribe_model_audio and self._context:
            asyncio.create_task(self._handle_transcribe_model_audio(audio, self._context))
        elif text:
            await self.push_frame(LLMFullResponseEndFrame())

        await self.push_frame(TTSStoppedFrame())

    def create_context_aggregator(
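To summarize the new text path end to end, the sketch below mirrors the buffering the two handlers above implement: the first chunk of a turn opens a response, each chunk is pushed as text, and turn completion closes the response. All names here are illustrative stand-ins, not the service's API.

```python
import asyncio


class TextTurnSketch:
    """Illustrative stand-in for the service's _bot_text_buffer handling."""

    def __init__(self):
        self._text_buffer = ""
        self.pushed = []  # stands in for push_frame() output

    async def handle_model_turn(self, text: str):
        if not self._text_buffer:
            # First chunk of the turn opens the response.
            self.pushed.append("LLMFullResponseStartFrame")
        self._text_buffer += text
        self.pushed.append(f"TextFrame({text!r})")

    async def handle_turn_complete(self):
        text, self._text_buffer = self._text_buffer, ""
        if text:
            # A completed text turn closes the response.
            self.pushed.append("LLMFullResponseEndFrame")


async def demo():
    turn = TextTurnSketch()
    for chunk in ("Hello", " world."):
        await turn.handle_model_turn(chunk)
    await turn.handle_turn_complete()
    print(turn.pushed)
    # ['LLMFullResponseStartFrame', "TextFrame('Hello')",
    #  "TextFrame(' world.')", 'LLMFullResponseEndFrame']


asyncio.run(demo())
```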
