Skip to content

Commit

Permalink
chore: sync code base with OSS repository (#53)
Browse files Browse the repository at this point in the history
Co-authored-by: Aleks Mitov <[email protected]>
  • Loading branch information
ploeber and aleks-mitov authored Feb 15, 2024
1 parent 1642920 commit 7a41020
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 52 deletions.
4 changes: 2 additions & 2 deletions assemblyai/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _get_error_message(response: httpx.Response) -> str:
try:
return response.json()["error"]
except Exception:
return response.text
return f"\nReason: {response.text}\nRequest: {response.request}"


def create_transcript(
Expand All @@ -43,7 +43,7 @@ def create_transcript(
)
if response.status_code != httpx.codes.ok:
raise types.TranscriptError(
f"failed to transcript url {request.audio_url}: {_get_error_message(response)}"
f"failed to transcribe url {request.audio_url}: {_get_error_message(response)}"
)

return types.TranscriptResponse.parse_obj(response.json())
Expand Down
92 changes: 74 additions & 18 deletions assemblyai/transcriber.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import base64
import concurrent.futures
import functools
import json
Expand Down Expand Up @@ -987,6 +986,7 @@ def __init__(
encoding: Optional[types.AudioEncoding] = None,
token: Optional[str] = None,
client: _client.Client,
end_utterance_silence_threshold: Optional[int],
) -> None:
self._client = client
self._websocket: Optional[websockets.sync.client.ClientConnection] = None
Expand All @@ -999,8 +999,9 @@ def __init__(
self._word_boost = word_boost
self._encoding = encoding
self._token = token
self._end_utterance_silence_threshold = end_utterance_silence_threshold

self._write_queue: queue.Queue[bytes] = queue.Queue()
self._write_queue: queue.Queue[Union[bytes, Dict]] = queue.Queue()
self._write_thread = threading.Thread(target=self._write)
self._read_thread = threading.Thread(target=self._read)
self._stop_event = threading.Event()
Expand Down Expand Up @@ -1048,13 +1049,40 @@ def connect(
self._read_thread.start()
self._write_thread.start()

if self._end_utterance_silence_threshold is not None:
self.configure_end_utterance_silence_threshold(
self._end_utterance_silence_threshold
)

def stream(self, data: bytes) -> None:
"""
Streams audio data to the real-time service by putting it into a queue.
"""

self._write_queue.put(data)

def configure_end_utterance_silence_threshold(
self, threshold_milliseconds: int
) -> None:
"""
Configures the end of utterance silence threshold.
Can be called multiple times during a session at any point after the session starts.
Args:
`threshold_milliseconds`: The threshold in milliseconds.
"""

self._write_queue.put(
_RealtimeEndUtteranceSilenceThreshold(threshold_milliseconds).as_dict()
)

def force_end_utterance(self) -> None:
"""
Forces the end of the current utterance.
"""

self._write_queue.put(_RealtimeForceEndUtterance().as_dict())

def close(self, terminate: bool = False) -> None:
"""
Closes the connection to the real-time service gracefully.
Expand Down Expand Up @@ -1116,25 +1144,12 @@ def _write(self) -> None:
if isinstance(data, dict):
self._websocket.send(json.dumps(data))
elif isinstance(data, bytes):
self._websocket.send(self._encode_data(data))
self._websocket.send(data)
else:
raise ValueError("unsupported message type")
except websockets.exceptions.ConnectionClosed as exc:
return self._handle_error(exc)

def _encode_data(self, data: bytes) -> str:
"""
Encodes the given audio chunk as a base64 string.
This is a helper method for `_write`.
"""

return json.dumps(
{
"audio_data": base64.b64encode(data).decode("utf-8"),
}
)

def _handle_message(
self,
message: Dict[str, Any],
Expand Down Expand Up @@ -1208,6 +1223,25 @@ def create_temporary_token(
)


class _RealtimeForceEndUtterance:
def as_dict(self) -> Dict[str, bool]:
return {
"force_end_utterance": True,
}


class _RealtimeEndUtteranceSilenceThreshold:
def __init__(self, threshold_milliseconds: int) -> None:
self._value = threshold_milliseconds

@property
def value(self) -> int:
return self._value

def as_dict(self) -> Dict[str, int]:
return {"end_utterance_silence_threshold": self._value}


class RealtimeTranscriber:
def __init__(
self,
Expand All @@ -1221,6 +1255,7 @@ def __init__(
encoding: Optional[types.AudioEncoding] = None,
token: Optional[str] = None,
client: Optional[_client.Client] = None,
end_utterance_silence_threshold: Optional[int] = None,
) -> None:
"""
Creates a new real-time transcriber.
Expand All @@ -1235,6 +1270,7 @@ def __init__(
`encoding`: (Optional) The encoding of the audio data.
`token`: (Optional) A temporary authentication token.
`client`: (Optional) The client to use for the real-time service.
`end_utterance_silence_threshold`: (Optional) The end utterance silence threshold in milliseconds.
"""

self._client = client or _client.Client.get_default(
Expand All @@ -1251,6 +1287,7 @@ def __init__(
encoding=encoding,
token=token,
client=self._client,
end_utterance_silence_threshold=end_utterance_silence_threshold,
)

def connect(
Expand All @@ -1268,8 +1305,7 @@ def connect(
self._impl.connect(timeout=timeout)

def stream(
self,
data: Union[bytes, Generator[bytes, None, None], Iterable[bytes]],
self, data: Union[bytes, Generator[bytes, None, None], Iterable[bytes]]
) -> None:
"""
Streams raw audio data to the real-time service.
Expand All @@ -1286,6 +1322,26 @@ def stream(
for chunk in data:
self._impl.stream(chunk)

def configure_end_utterance_silence_threshold(
self, threshold_milliseconds: int
) -> None:
"""
Configures the silence duration threshold used to detect the end of an utterance.
In practice, it's used to tune how the transcriptions are split into final transcripts.
Can be called multiple times during a session at any point after the session starts.
Args:
`threshold_milliseconds`: The threshold in milliseconds.
"""
self._impl.configure_end_utterance_silence_threshold(threshold_milliseconds)

def force_end_utterance(self) -> None:
"""
Forces the end of the current utterance.
After calling this method, the server will end the current utterance and return a final transcript.
"""
self._impl.force_end_utterance()

def close(self) -> None:
"""
Closes the connection to the real-time service.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setup(
name="assemblyai",
version="0.21.0",
version="0.22.0",
description="AssemblyAI Python SDK",
author="AssemblyAI",
author_email="[email protected]",
Expand Down
34 changes: 3 additions & 31 deletions tests/unit/test_realtime_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,38 +274,10 @@ def mocked_send(data: str):

transcriber._impl._write()

# assert that the correct data was sent (base64 encoded)
# assert that the correct data was sent (= the exact input bytes)
assert len(actual_sent) == 2
assert json.loads(actual_sent[0]) == {"audio_data": "AQIDBAU="}
assert json.loads(actual_sent[1]) == {"audio_data": "BgcICQo="}


def test_realtime__encode_data(mocker: MockFixture):
"""
Tests the `_encode_data` method of the `_RealtimeTranscriberImpl` class.
"""

audio_chunks = [
bytes([1, 2, 3, 4, 5]),
bytes([6, 7, 8, 9, 10]),
]

expected_encoded_data = [
json.dumps({"audio_data": "AQIDBAU="}),
json.dumps({"audio_data": "BgcICQo="}),
]

transcriber = aai.RealtimeTranscriber(
on_data=lambda _: None,
on_error=lambda _: None,
sample_rate=44_100,
)

actual_encoded_data = []
for chunk in audio_chunks:
actual_encoded_data.append(transcriber._impl._encode_data(chunk))

assert actual_encoded_data == expected_encoded_data
assert actual_sent[0] == audio_chunks[0]
assert actual_sent[1] == audio_chunks[1]


def test_realtime__handle_message_session_begins(mocker: MockFixture):
Expand Down

0 comments on commit 7a41020

Please sign in to comment.