Skip to content

Commit

Permalink
feat(python/sdk): Add multichannel support (#6814)
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 668f1b8c9e1339b818ef321887ab7e8fe363d58c
  • Loading branch information
ploeber committed Nov 6, 2024
1 parent 39b0552 commit d2ac671
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 3 deletions.
28 changes: 27 additions & 1 deletion assemblyai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,9 @@ class RawTranscriptionConfig(BaseModel):
dual_channel: Optional[bool] = None
"Enable Dual Channel transcription"

multichannel: Optional[bool] = None
"Enable Multichannel transcription"

webhook_url: Optional[str] = None
"The URL we should send webhooks to when your transcript is complete."
webhook_auth_header_name: Optional[str] = None
Expand Down Expand Up @@ -578,6 +581,7 @@ def __init__(
punctuate: Optional[bool] = None,
format_text: Optional[bool] = None,
dual_channel: Optional[bool] = None,
multichannel: Optional[bool] = None,
webhook_url: Optional[str] = None,
webhook_auth_header_name: Optional[str] = None,
webhook_auth_header_value: Optional[str] = None,
Expand Down Expand Up @@ -617,6 +621,7 @@ def __init__(
punctuate: Enable Automatic Punctuation
format_text: Enable Text Formatting
dual_channel: Enable Dual Channel transcription
multichannel: Enable Multichannel transcription
webhook_url: The URL we should send webhooks to when your transcript is complete.
webhook_auth_header_name: The name of the header that is sent when the `webhook_url` is being called.
webhook_auth_header_value: The value of the `webhook_auth_header_name` that is sent when the `webhook_url` is being called.
Expand Down Expand Up @@ -660,6 +665,7 @@ def __init__(
self.punctuate = punctuate
self.format_text = format_text
self.dual_channel = dual_channel
self.multichannel = multichannel
self.set_webhook(
webhook_url,
webhook_auth_header_name,
Expand Down Expand Up @@ -760,6 +766,18 @@ def dual_channel(self, enable: Optional[bool]) -> None:

self._raw_transcription_config.dual_channel = enable

@property
def multichannel(self) -> Optional[bool]:
    "Current value of the Multichannel transcription option"

    raw_config = self._raw_transcription_config
    return raw_config.multichannel

@multichannel.setter
def multichannel(self, enable: Optional[bool]) -> None:
    "Turn Multichannel transcription on or off"

    raw_config = self._raw_transcription_config
    raw_config.multichannel = enable

@property
def webhook_url(self) -> Optional[str]:
"The URL we should send webhooks to when your transcript is complete."
Expand Down Expand Up @@ -1391,6 +1409,7 @@ class Word(BaseModel):
end: int
confidence: float
speaker: Optional[str] = None
channel: Optional[str] = None


class UtteranceWord(Word):
Expand Down Expand Up @@ -1485,6 +1504,7 @@ class IABResponse(BaseModel):
# A `Word` extended with a sentiment label.
class Sentiment(Word):
    # Sentiment classification attached to this span.
    sentiment: SentimentType
    # NOTE(review): `speaker` and `channel` appear to repeat fields already
    # declared on `Word` with the same type and default — confirm whether the
    # redeclaration is intentional.
    speaker: Optional[str] = None
    channel: Optional[str] = None


class Entity(BaseModel):
Expand Down Expand Up @@ -1530,6 +1550,7 @@ class Sentence(Word):
end: int
confidence: float
speaker: Optional[str] = None
channel: Optional[str] = None


class SentencesResponse(BaseModel):
Expand Down Expand Up @@ -1576,6 +1597,11 @@ class BaseTranscript(BaseModel):
dual_channel: Optional[bool] = None
"Enable Dual Channel transcription"

multichannel: Optional[bool] = None
"Enable Multichannel transcription"
audio_channels: Optional[int] = None
"The number of audio channels in the media file"

webhook_url: Optional[str] = None
"The URL we should send webhooks to when your transcript is complete."
webhook_auth_header_name: Optional[str] = None
Expand Down Expand Up @@ -1694,7 +1720,7 @@ class TranscriptResponse(BaseTranscript):
"A list of all the individual words transcribed"

utterances: Optional[List[Utterance]] = None
"When `dual_channel` or `speaker_labels` is enabled, a list of turn-by-turn utterances"
"When `dual_channel`, `multichannel`, or `speaker_labels` is enabled, a list of turn-by-turn utterances"

confidence: Optional[float] = None
"The confidence our model has in the transcribed text, between 0.0 and 1.0"
Expand Down
9 changes: 7 additions & 2 deletions tests/unit/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,16 @@ class Meta:
start = factory.Faker("pyint")
end = factory.Faker("pyint")
confidence = factory.Faker("pyfloat", min_value=0.0, max_value=1.0)
speaker = "1"
channel = "1"


class UtteranceWordFactory(WordFactory):
class Meta:
model = aai.UtteranceWord

speaker = factory.Faker("name")
speaker = "1"
channel = "1"


class UtteranceFactory(UtteranceWordFactory):
Expand Down Expand Up @@ -65,7 +68,8 @@ class Meta:
audio_url = factory.Faker("url")
punctuate = True
format_text = True
dual_channel = True
multichannel = None
dual_channel = None
webhook_url = None
webhook_auth_header_name = None
audio_start_from = None
Expand Down Expand Up @@ -119,6 +123,7 @@ class TranscriptDeletedResponseFactory(BaseTranscriptResponseFactory):
punctuate = None
format_text = None
dual_channel = None
multichannel = None
webhook_url = "http://deleted_by_user"
webhook_status_code = None
webhook_auth = False
Expand Down
58 changes: 58 additions & 0 deletions tests/unit/test_multichannel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from pytest_httpx import HTTPXMock

import tests.unit.unit_test_utils as unit_test_utils
import assemblyai as aai
from tests.unit import factories

aai.settings.api_key = "test"


# Completed-transcript response factory with the multichannel fields populated;
# used as the mock API response in the multichannel tests in this module.
class MultichannelResponseFactory(factories.TranscriptCompletedResponseFactory):
    # Mark the mock transcript as multichannel.
    multichannel = True
    # More than one channel, so the `audio_channels > 1` check in the test holds.
    audio_channels = 2


def test_multichannel_disabled_by_default(httpx_mock: HTTPXMock):
    """
    Verifies that a `TranscriptionConfig` created without `multichannel=True`
    keeps the default behavior: no `multichannel` value in the request body.
    """
    completed_response = factories.generate_dict_factory(
        factories.TranscriptCompletedResponseFactory
    )()
    default_config = aai.TranscriptionConfig()

    request_body, transcript = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=completed_response,
        config=default_config,
    )

    # Neither the outgoing request nor the mocked response carries the flag
    assert request_body.get("multichannel") is None
    assert transcript.json_response.get("multichannel") is None


def test_multichannel_enabled(httpx_mock: HTTPXMock):
    """
    Tests that setting `multichannel=True` in the `TranscriptionConfig`
    will result in correct `multichannel` in the request body, and that the
    response is properly parsed into the `multichannel` and `utterances` field.
    """

    mock_response = factories.generate_dict_factory(MultichannelResponseFactory)()
    request_body, transcript = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=mock_response,
        config=aai.TranscriptionConfig(multichannel=True),
    )

    # Check that the flag was sent as enabled, not merely present in the body
    multichannel_request = request_body.get("multichannel")
    assert multichannel_request is True

    # Check that transcript has no errors and multichannel response is correctly returned
    assert transcript.error is None
    assert transcript.json_response["multichannel"] == multichannel_request
    assert transcript.json_response["audio_channels"] > 1

    # Check that utterances are correctly parsed
    assert transcript.utterances is not None
    assert len(transcript.utterances) > 0
    for utterance in transcript.utterances:
        # Guard first so a missing channel fails as an assertion rather than
        # as a TypeError raised by int(None)
        assert utterance.channel is not None
        assert int(utterance.channel) > 0

0 comments on commit d2ac671

Please sign in to comment.