Skip to content

Commit

Permalink
fix(python/sdk): Fix custom spelling property and add unit tests (#6763)
Browse files Browse the repository at this point in the history
GitOrigin-RevId: c5d69d94605abad2bef7510b8ba61bfa1d079db3
  • Loading branch information
ploeber committed Oct 18, 2024
1 parent 47bfa69 commit c9f35e3
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 10 deletions.
24 changes: 14 additions & 10 deletions assemblyai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ class RawTranscriptionConfig(BaseModel):
"Enable Topic Detection."

custom_spelling: Optional[List[Dict[str, Union[str, List[str]]]]] = None
"Customize how words are spelled and formatted using to and from values"
"Customize how words are spelled and formatted using to and from values."

disfluencies: Optional[bool] = None
"Transcribe Filler Words, like 'umm', in your media file."
Expand Down Expand Up @@ -916,18 +916,21 @@ def iab_categories(self, enable: Optional[bool]) -> None:

@property
def custom_spelling(self) -> Optional[Dict[str, Union[str, List[str]]]]:
    """
    Returns the current set of custom spellings.

    For each key-value pair in the returned dictionary, the key is the 'to'
    field, and the value is the 'from' field.

    Returns:
        The custom spellings, or `None` if none are configured (the raw
        config value is either `None` or an empty list).

    Raises:
        ValueError: If the 'to' field of an entry is not a string.
    """
    raw_spellings = self._raw_transcription_config.custom_spelling
    if raw_spellings is None:
        return None

    custom_spellings = {}
    for entry in raw_spellings:
        _to = entry["to"]
        if not isinstance(_to, str):
            raise ValueError("`to` argument must be a string!")

        custom_spellings[_to] = entry["from"]

    # Check the accumulated mapping rather than the loop variable: the loop
    # variable is unbound when the raw list is empty, which used to raise
    # instead of returning `None`.
    return custom_spellings or None

Expand Down Expand Up @@ -1231,13 +1234,14 @@ def set_custom_spelling(
Customize how given words are being spelled or formatted in the transcription's text.
Args:
replacement: A dictionary that contains the replacement object (see below example)
replacement: A dictionary that contains the replacement object (see below example).
For each key-value pair, the key is the 'to' field, and the value is the 'from' field.
override: If `True`, the existing custom spellings are overridden with the given `replacement` argument; otherwise they are merged.
Example:
```
config.set_custom_spelling({
"AssemblyAI": "AssemblyAI",
"AssemblyAI": "assemblyAI",
"Kubernetes": ["k8s", "kubernetes"]
})
```
Expand Down Expand Up @@ -1619,7 +1623,7 @@ class BaseTranscript(BaseModel):
"Enable Topic Detection."

custom_spelling: Optional[List[Dict[str, Union[str, List[str]]]]] = None
"Customize how words are spelled and formatted using to and from values"
"Customize how words are spelled and formatted using to and from values."

disfluencies: Optional[bool] = None
"Transcribe Filler Words, like 'umm', in your media file."
Expand Down
93 changes: 93 additions & 0 deletions tests/unit/test_custom_spelling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import factory
from pytest_httpx import HTTPXMock

import tests.unit.unit_test_utils as unit_test_utils
import assemblyai as aai
from tests.unit import factories

aai.settings.api_key = "test"


class CustomSpellingFactory(factory.Factory):
    """Builds a single custom spelling entry as a plain dictionary."""

    class Meta:
        # Instances are plain dicts; `from` is a Python keyword, so the field
        # is declared as `_from` and renamed when the dict is built.
        model = dict
        rename = {"_from": "from"}

    # 'from' holds a list of words.
    _from = factory.List([factory.Faker("word")])
    # 'to' holds exactly one word.
    to = factory.Faker("word")


class CustomSpellingResponseFactory(factories.TranscriptCompletedResponseFactory):
    """Completed-transcript response that carries one custom spelling entry."""

    @factory.lazy_attribute
    def custom_spelling(self):
        # Generated lazily so each built response gets its own fresh entry.
        spelling_entry = CustomSpellingFactory()
        return [spelling_entry]


def test_custom_spelling_disabled_by_default(httpx_mock: HTTPXMock):
    """
    Tests that a `TranscriptionConfig` on which `set_custom_spelling()` was never
    called leaves `custom_spelling` out of the request body and the response.
    """
    build_response = factories.generate_dict_factory(
        factories.TranscriptCompletedResponseFactory
    )
    completed_response = build_response()
    default_config = aai.TranscriptionConfig()

    sent_body, result = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=completed_response,
        config=default_config,
    )

    assert sent_body.get("custom_spelling") is None
    assert result.json_response.get("custom_spelling") is None


def test_custom_spelling_set_config_succeeds():
    """
    Tests that `set_custom_spelling()` stores the given values on the
    `TranscriptionConfig`, and that the `custom_spelling` property returns
    them again.
    """
    config = aai.TranscriptionConfig()

    # A plain string 'from' value comes back wrapped in a list
    single_pair = {"AssemblyAI": "assemblyAI"}
    config.set_custom_spelling(single_pair)
    assert config.custom_spelling == {"AssemblyAI": ["assemblyAI"]}

    # Several pairs can be set at once, replacing the previous ones
    several_pairs = {"AssemblyAI": "assemblyAI", "Kubernetes": ["k8s", "kubernetes"]}
    config.set_custom_spelling(several_pairs, override=True)

    expected_spellings = {
        "AssemblyAI": ["assemblyAI"],
        "Kubernetes": ["k8s", "kubernetes"],
    }
    assert config.custom_spelling == expected_spellings


def test_custom_spelling_enabled(httpx_mock: HTTPXMock):
    """
    Tests that `set_custom_spelling()` on the `TranscriptionConfig` produces a
    correct `custom_spelling` entry in the request body, and that the response
    exposes a matching `custom_spelling` field.
    """
    mock_response = factories.generate_dict_factory(CustomSpellingResponseFactory)()

    # Build the config from the values the mocked response will return
    mocked_entry = mock_response["custom_spelling"][0]
    from_ = mocked_entry["from"]
    to = mocked_entry["to"]
    config = aai.TranscriptionConfig().set_custom_spelling({to: from_})

    request_body, transcript = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=mock_response,
        config=config,
    )

    # The request body must carry at least one entry with both fields
    sent_spellings = request_body["custom_spelling"]
    assert sent_spellings is not None
    assert len(sent_spellings) > 0
    first_sent = sent_spellings[0]
    assert "from" in first_sent
    assert "to" in first_sent

    # The transcript must be error-free and echo what was sent
    assert transcript.error is None
    assert transcript.json_response["custom_spelling"] == sent_spellings

0 comments on commit c9f35e3

Please sign in to comment.