From 7a4102032a90f8c97a57a04cb1e67f99f74954fb Mon Sep 17 00:00:00 2001 From: Patrick Loeber <98830383+ploeber@users.noreply.github.com> Date: Thu, 15 Feb 2024 13:38:30 +0100 Subject: [PATCH] chore: sync code base with OSS repository (#53) Co-authored-by: Aleks Mitov <140423361+aleks-mitov@users.noreply.github.com> --- assemblyai/api.py | 4 +- assemblyai/transcriber.py | 92 ++++++++++++++++++++----- setup.py | 2 +- tests/unit/test_realtime_transcriber.py | 34 +-------- 4 files changed, 80 insertions(+), 52 deletions(-) diff --git a/assemblyai/api.py b/assemblyai/api.py index d3d87d2..1362204 100644 --- a/assemblyai/api.py +++ b/assemblyai/api.py @@ -27,7 +27,7 @@ def _get_error_message(response: httpx.Response) -> str: try: return response.json()["error"] except Exception: - return response.text + return f"\nReason: {response.text}\nRequest: {response.request}" def create_transcript( @@ -43,7 +43,7 @@ def create_transcript( ) if response.status_code != httpx.codes.ok: raise types.TranscriptError( - f"failed to transcript url {request.audio_url}: {_get_error_message(response)}" + f"failed to transcribe url {request.audio_url}: {_get_error_message(response)}" ) return types.TranscriptResponse.parse_obj(response.json()) diff --git a/assemblyai/transcriber.py b/assemblyai/transcriber.py index 913563b..e64d482 100644 --- a/assemblyai/transcriber.py +++ b/assemblyai/transcriber.py @@ -1,6 +1,5 @@ from __future__ import annotations -import base64 import concurrent.futures import functools import json @@ -987,6 +986,7 @@ def __init__( encoding: Optional[types.AudioEncoding] = None, token: Optional[str] = None, client: _client.Client, + end_utterance_silence_threshold: Optional[int], ) -> None: self._client = client self._websocket: Optional[websockets.sync.client.ClientConnection] = None @@ -999,8 +999,9 @@ def __init__( self._word_boost = word_boost self._encoding = encoding self._token = token + self._end_utterance_silence_threshold = end_utterance_silence_threshold - self._write_queue: queue.Queue[bytes] = queue.Queue() + self._write_queue: queue.Queue[Union[bytes, Dict]] = queue.Queue() self._write_thread = threading.Thread(target=self._write) self._read_thread = threading.Thread(target=self._read) self._stop_event = threading.Event() @@ -1048,6 +1049,11 @@ def connect( self._read_thread.start() self._write_thread.start() + if self._end_utterance_silence_threshold is not None: + self.configure_end_utterance_silence_threshold( + self._end_utterance_silence_threshold + ) + def stream(self, data: bytes) -> None: """ Streams audio data to the real-time service by putting it into a queue. @@ -1055,6 +1061,28 @@ def stream(self, data: bytes) -> None: self._write_queue.put(data) + def configure_end_utterance_silence_threshold( + self, threshold_milliseconds: int + ) -> None: + """ + Configures the end of utterance silence threshold. + Can be called multiple times during a session at any point after the session starts. + + Args: + `threshold_milliseconds`: The threshold in milliseconds. + """ + + self._write_queue.put( + _RealtimeEndUtteranceSilenceThreshold(threshold_milliseconds).as_dict() + ) + + def force_end_utterance(self) -> None: + """ + Forces the end of the current utterance. + """ + + self._write_queue.put(_RealtimeForceEndUtterance().as_dict()) + def close(self, terminate: bool = False) -> None: """ Closes the connection to the real-time service gracefully. @@ -1116,25 +1144,12 @@ def _write(self) -> None: if isinstance(data, dict): self._websocket.send(json.dumps(data)) elif isinstance(data, bytes): - self._websocket.send(self._encode_data(data)) + self._websocket.send(data) else: raise ValueError("unsupported message type") except websockets.exceptions.ConnectionClosed as exc: return self._handle_error(exc) - def _encode_data(self, data: bytes) -> str: - """ - Encodes the given audio chunk as a base64 string. - - This is a helper method for `_write`. - """ - - return json.dumps( - { - "audio_data": base64.b64encode(data).decode("utf-8"), - } - ) - def _handle_message( self, message: Dict[str, Any], @@ -1208,6 +1223,25 @@ def create_temporary_token( ) +class _RealtimeForceEndUtterance: + def as_dict(self) -> Dict[str, bool]: + return { + "force_end_utterance": True, + } + + +class _RealtimeEndUtteranceSilenceThreshold: + def __init__(self, threshold_milliseconds: int) -> None: + self._value = threshold_milliseconds + + @property + def value(self) -> int: + return self._value + + def as_dict(self) -> Dict[str, int]: + return {"end_utterance_silence_threshold": self._value} + + class RealtimeTranscriber: def __init__( self, @@ -1221,6 +1255,7 @@ def __init__( encoding: Optional[types.AudioEncoding] = None, token: Optional[str] = None, client: Optional[_client.Client] = None, + end_utterance_silence_threshold: Optional[int] = None, ) -> None: """ Creates a new real-time transcriber. @@ -1235,6 +1270,7 @@ def __init__( `encoding`: (Optional) The encoding of the audio data. `token`: (Optional) A temporary authentication token. `client`: (Optional) The client to use for the real-time service. + `end_utterance_silence_threshold`: (Optional) The end utterance silence threshold in milliseconds. """ self._client = client or _client.Client.get_default( @@ -1251,6 +1287,7 @@ def __init__( encoding=encoding, token=token, client=self._client, + end_utterance_silence_threshold=end_utterance_silence_threshold, ) def connect( @@ -1268,8 +1305,7 @@ def connect( self._impl.connect(timeout=timeout) def stream( - self, - data: Union[bytes, Generator[bytes, None, None], Iterable[bytes]], + self, data: Union[bytes, Generator[bytes, None, None], Iterable[bytes]] ) -> None: """ Streams raw audio data to the real-time service. @@ -1286,6 +1322,26 @@ def stream( for chunk in data: self._impl.stream(chunk) + def configure_end_utterance_silence_threshold( + self, threshold_milliseconds: int + ) -> None: + """ + Configures the silence duration threshold used to detect the end of an utterance. + In practice, it's used to tune how the transcriptions are split into final transcripts. + Can be called multiple times during a session at any point after the session starts. + + Args: + `threshold_milliseconds`: The threshold in milliseconds. + """ + self._impl.configure_end_utterance_silence_threshold(threshold_milliseconds) + + def force_end_utterance(self) -> None: + """ + Forces the end of the current utterance. + After calling this method, the server will end the current utterance and return a final transcript. + """ + self._impl.force_end_utterance() + def close(self) -> None: """ Closes the connection to the real-time service. diff --git a/setup.py b/setup.py index 5411478..2b0e666 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ setup( name="assemblyai", - version="0.21.0", + version="0.22.0", description="AssemblyAI Python SDK", author="AssemblyAI", author_email="engineering.sdk@assemblyai.com", diff --git a/tests/unit/test_realtime_transcriber.py b/tests/unit/test_realtime_transcriber.py index 9c46d00..aa4da55 100644 --- a/tests/unit/test_realtime_transcriber.py +++ b/tests/unit/test_realtime_transcriber.py @@ -274,38 +274,10 @@ def mocked_send(data: str): transcriber._impl._write() - # assert that the correct data was sent (base64 encoded) + # assert that the correct data was sent (= the exact input bytes) assert len(actual_sent) == 2 - assert json.loads(actual_sent[0]) == {"audio_data": "AQIDBAU="} - assert json.loads(actual_sent[1]) == {"audio_data": "BgcICQo="} - - -def test_realtime__encode_data(mocker: MockFixture): - """ - Tests the `_encode_data` method of the `_RealtimeTranscriberImpl` class. - """ - - audio_chunks = [ - bytes([1, 2, 3, 4, 5]), - bytes([6, 7, 8, 9, 10]), - ] - - expected_encoded_data = [ - json.dumps({"audio_data": "AQIDBAU="}), - json.dumps({"audio_data": "BgcICQo="}), - ] - - transcriber = aai.RealtimeTranscriber( - on_data=lambda _: None, - on_error=lambda _: None, - sample_rate=44_100, - ) - - actual_encoded_data = [] - for chunk in audio_chunks: - actual_encoded_data.append(transcriber._impl._encode_data(chunk)) - - assert actual_encoded_data == expected_encoded_data + assert actual_sent[0] == audio_chunks[0] + assert actual_sent[1] == audio_chunks[1] def test_realtime__handle_message_session_begins(mocker: MockFixture):