From d2ac6710590ac9db0161d0dd81ce554ad293fc06 Mon Sep 17 00:00:00 2001 From: Patrick Loeber <98830383+ploeber@users.noreply.github.com> Date: Fri, 18 Oct 2024 15:41:18 +0200 Subject: [PATCH] feat(python/sdk): Add multichannel support (#6814) GitOrigin-RevId: 668f1b8c9e1339b818ef321887ab7e8fe363d58c --- assemblyai/types.py | 28 +++++++++++++++- tests/unit/factories.py | 9 +++-- tests/unit/test_multichannel.py | 58 +++++++++++++++++++++++++++++++++ 3 files changed, 92 insertions(+), 3 deletions(-) create mode 100644 tests/unit/test_multichannel.py diff --git a/assemblyai/types.py b/assemblyai/types.py index 269b786..daaff28 100644 --- a/assemblyai/types.py +++ b/assemblyai/types.py @@ -477,6 +477,9 @@ class RawTranscriptionConfig(BaseModel): dual_channel: Optional[bool] = None "Enable Dual Channel transcription" + multichannel: Optional[bool] = None + "Enable Multichannel transcription" + webhook_url: Optional[str] = None "The URL we should send webhooks to when your transcript is complete." webhook_auth_header_name: Optional[str] = None @@ -578,6 +581,7 @@ def __init__( punctuate: Optional[bool] = None, format_text: Optional[bool] = None, dual_channel: Optional[bool] = None, + multichannel: Optional[bool] = None, webhook_url: Optional[str] = None, webhook_auth_header_name: Optional[str] = None, webhook_auth_header_value: Optional[str] = None, @@ -617,6 +621,7 @@ def __init__( punctuate: Enable Automatic Punctuation format_text: Enable Text Formatting dual_channel: Enable Dual Channel transcription + multichannel: Enable Multichannel transcription webhoook_url: The URL we should send webhooks to when your transcript is complete. webhook_auth_header_name: The name of the header that is sent when the `webhook_url` is being called. webhook_auth_header_value: The value of the `webhook_auth_header_name` that is sent when the `webhoook_url` is being called. @@ -660,6 +665,7 @@ def __init__( self.punctuate = punctuate self.format_text = format_text self.dual_channel = dual_channel + self.multichannel = multichannel self.set_webhook( webhook_url, webhook_auth_header_name, @@ -760,6 +766,18 @@ def dual_channel(self, enable: Optional[bool]) -> None: self._raw_transcription_config.dual_channel = enable + @property + def multichannel(self) -> Optional[bool]: + "Returns the status of the Multichannel transcription feature" + + return self._raw_transcription_config.multichannel + + @multichannel.setter + def multichannel(self, enable: Optional[bool]) -> None: + "Enable Multichannel transcription" + + self._raw_transcription_config.multichannel = enable + @property def webhook_url(self) -> Optional[str]: "The URL we should send webhooks to when your transcript is complete." @@ -1391,6 +1409,7 @@ class Word(BaseModel): end: int confidence: float speaker: Optional[str] = None + channel: Optional[str] = None class UtteranceWord(Word): @@ -1485,6 +1504,7 @@ class IABResponse(BaseModel): class Sentiment(Word): sentiment: SentimentType speaker: Optional[str] = None + channel: Optional[str] = None class Entity(BaseModel): @@ -1530,6 +1550,7 @@ class Sentence(Word): end: int confidence: float speaker: Optional[str] = None + channel: Optional[str] = None class SentencesResponse(BaseModel): @@ -1576,6 +1597,11 @@ class BaseTranscript(BaseModel): dual_channel: Optional[bool] = None "Enable Dual Channel transcription" + multichannel: Optional[bool] = None + "Enable Multichannel transcription" + audio_channels: Optional[int] = None + "The number of audio channels in the media file" + webhook_url: Optional[str] = None "The URL we should send webhooks to when your transcript is complete." webhook_auth_header_name: Optional[str] = None @@ -1694,7 +1720,7 @@ class TranscriptResponse(BaseTranscript): "A list of all the individual words transcribed" utterances: Optional[List[Utterance]] = None - "When `dual_channel` or `speaker_labels` is enabled, a list of turn-by-turn utterances" + "When `dual_channel`, `multichannel`, or `speaker_labels` is enabled, a list of turn-by-turn utterances" confidence: Optional[float] = None "The confidence our model has in the transcribed text, between 0.0 and 1.0" diff --git a/tests/unit/factories.py b/tests/unit/factories.py index d3c9724..6e66c52 100644 --- a/tests/unit/factories.py +++ b/tests/unit/factories.py @@ -30,13 +30,16 @@ class Meta: start = factory.Faker("pyint") end = factory.Faker("pyint") confidence = factory.Faker("pyfloat", min_value=0.0, max_value=1.0) + speaker = "1" + channel = "1" class UtteranceWordFactory(WordFactory): class Meta: model = aai.UtteranceWord - speaker = factory.Faker("name") + speaker = "1" + channel = "1" class UtteranceFactory(UtteranceWordFactory): @@ -65,7 +68,8 @@ class Meta: audio_url = factory.Faker("url") punctuate = True format_text = True - dual_channel = True + multichannel = None + dual_channel = None webhook_url = None webhook_auth_header_name = None audio_start_from = None @@ -119,6 +123,7 @@ class TranscriptDeletedResponseFactory(BaseTranscriptResponseFactory): punctuate = None format_text = None dual_channel = None + multichannel = None webhook_url = "http://deleted_by_user" webhook_status_code = None webhook_auth = False diff --git a/tests/unit/test_multichannel.py b/tests/unit/test_multichannel.py new file mode 100644 index 0000000..72a2bd4 --- /dev/null +++ b/tests/unit/test_multichannel.py @@ -0,0 +1,58 @@ +from pytest_httpx import HTTPXMock + +import tests.unit.unit_test_utils as unit_test_utils +import assemblyai as aai +from tests.unit import factories + +aai.settings.api_key = "test" + + +class MultichannelResponseFactory(factories.TranscriptCompletedResponseFactory): + multichannel = True + audio_channels = 2 + + +def test_multichannel_disabled_by_default(httpx_mock: HTTPXMock): + """ + Tests that not setting `multichannel=True` in the `TranscriptionConfig` + will result in the default behavior of it being excluded from the request body. + """ + request_body, transcript = unit_test_utils.submit_mock_transcription_request( + httpx_mock, + mock_response=factories.generate_dict_factory( + factories.TranscriptCompletedResponseFactory + )(), + config=aai.TranscriptionConfig(), + ) + assert request_body.get("multichannel") is None + assert transcript.json_response.get("multichannel") is None + + +def test_multichannel_enabled(httpx_mock: HTTPXMock): + """ + Tests that not setting `multichannel=True` in the `TranscriptionConfig` + will result in correct `multichannel` in the request body, and that the + response is properly parsed into the `multichannel` and `utterances` field. + """ + + mock_response = factories.generate_dict_factory(MultichannelResponseFactory)() + request_body, transcript = unit_test_utils.submit_mock_transcription_request( + httpx_mock, + mock_response=mock_response, + config=aai.TranscriptionConfig(multichannel=True), + ) + + # Check that request body was properly defined + multichannel_response = request_body.get("multichannel") + assert multichannel_response is not None + + # Check that transcript has no errors and multichannel response is correctly returned + assert transcript.error is None + assert transcript.json_response["multichannel"] == multichannel_response + assert transcript.json_response["audio_channels"] > 1 + + # Check that utterances are correctly parsed + assert transcript.utterances is not None + assert len(transcript.utterances) > 0 + for utterance in transcript.utterances: + assert int(utterance.channel) > 0