From 1abf5d9cb849bbfc9a2dd9ea1cbbea62c3ea19a7 Mon Sep 17 00:00:00 2001 From: Patrick Loeber <98830383+ploeber@users.noreply.github.com> Date: Wed, 25 Oct 2023 16:21:34 -0400 Subject: [PATCH] fix(python/sdk): Fix README for sentiment analysis example (#2472) GitOrigin-RevId: c2d97c625499679492c8d3c9083a39bc74211c83 --- README.md | 2 +- setup.py | 2 +- tests/unit/test_auto_chapters.py | 63 +------------ tests/unit/test_auto_highlights.py | 61 +----------- tests/unit/test_content_safety.py | 64 +------------ tests/unit/test_entity_detection.py | 61 +----------- tests/unit/test_iab_categories.py | 61 +----------- tests/unit/test_redact_pii.py | 91 ++++-------------- tests/unit/test_sentiment_analysis.py | 61 +----------- tests/unit/test_summarization.py | 129 +++++++++++++------------- tests/unit/unit_test_utils.py | 69 ++++++++++++++ 11 files changed, 178 insertions(+), 486 deletions(-) create mode 100644 tests/unit/unit_test_utils.py diff --git a/README.md b/README.md index 0ce362d..eb2a7de 100644 --- a/README.md +++ b/README.md @@ -460,7 +460,7 @@ for sentiment_result in transcript.sentiment_analysis: print(sentiment_result.text) print(sentiment_result.sentiment) # POSITIVE, NEUTRAL, or NEGATIVE print(sentiment_result.confidence) - print(f"Timestamp: {sentiment_result.timestamp.start} - {sentiment_result.timestamp.end}") + print(f"Timestamp: {sentiment_result.start} - {sentiment_result.end}") ``` If `speaker_labels` is also enabled, then each sentiment analysis result will also include a `speaker` field. diff --git a/setup.py b/setup.py index 6aacfbc..5ff7e4f 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ setup( name="assemblyai", - version="0.19.0", + version="0.18.0", description="AssemblyAI Python SDK", author="AssemblyAI", author_email="engineering.sdk@assemblyai.com", diff --git a/tests/unit/test_auto_chapters.py b/tests/unit/test_auto_chapters.py index 7f1caf3..8b03965 100644 --- a/tests/unit/test_auto_chapters.py +++ b/tests/unit/test_auto_chapters.py @@ -1,13 +1,9 @@ -import json -from typing import Any, Dict, Tuple - import factory -import httpx import pytest from pytest_httpx import HTTPXMock +import tests.unit.unit_test_utils as unit_test_utils import assemblyai as aai -from assemblyai.api import ENDPOINT_TRANSCRIPT from tests.unit import factories aai.settings.api_key = "test" @@ -17,57 +13,6 @@ class AutoChaptersResponseFactory(factories.TranscriptCompletedResponseFactory): chapters = factory.List([factory.SubFactory(factories.ChapterFactory)]) -def __submit_mock_request( - httpx_mock: HTTPXMock, - mock_response: Dict[str, Any], - config: aai.TranscriptionConfig, -) -> Tuple[Dict[str, Any], aai.Transcript]: - """ - Helper function to abstract mock transcriber calls with given `TranscriptionConfig`, - and perform some common assertions. - """ - - mock_transcript_id = mock_response.get("id", "mock_id") - - # Mock initial submission response (transcript is processing) - mock_processing_response = factories.generate_dict_factory( - factories.TranscriptProcessingResponseFactory - )() - - httpx_mock.add_response( - url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}", - status_code=httpx.codes.OK, - method="POST", - json={ - **mock_processing_response, - "id": mock_transcript_id, # inject ID from main mock response - }, - ) - - # Mock polling-for-completeness response, with completed transcript - httpx_mock.add_response( - url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{mock_transcript_id}", - status_code=httpx.codes.OK, - method="GET", - json=mock_response, - ) - - # == Make API request via SDK == - transcript = aai.Transcriber().transcribe( - data="https://example.org/audio.wav", - config=config, - ) - - # Check that submission and polling requests were made - assert len(httpx_mock.get_requests()) == 2 - - # Extract body of initial submission request - request = httpx_mock.get_requests()[0] - request_body = json.loads(request.content.decode()) - - return request_body, transcript - - def test_auto_chapters_fails_without_punctuation(httpx_mock: HTTPXMock): """ Tests whether the SDK raises an error before making a request @@ -75,7 +20,7 @@ def test_auto_chapters_fails_without_punctuation(httpx_mock: HTTPXMock): """ with pytest.raises(ValueError) as error: - __submit_mock_request( + unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response={}, # response doesn't matter, since it shouldn't occur config=aai.TranscriptionConfig( @@ -98,7 +43,7 @@ def test_auto_chapters_disabled_by_default(httpx_mock: HTTPXMock): Tests that excluding `auto_chapters` from the `TranscriptionConfig` will result in the default behavior of it being excluded from the request body """ - request_body, transcript = __submit_mock_request( + request_body, transcript = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response=factories.generate_dict_factory( factories.TranscriptCompletedResponseFactory @@ -116,7 +61,7 @@ def test_auto_chapters_enabled(httpx_mock: HTTPXMock): response is properly parsed into a `Transcript` object """ mock_response = factories.generate_dict_factory(AutoChaptersResponseFactory)() - request_body, transcript = __submit_mock_request( + request_body, transcript = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response=mock_response, config=aai.TranscriptionConfig(auto_chapters=True), diff --git a/tests/unit/test_auto_highlights.py b/tests/unit/test_auto_highlights.py index 0273b60..8472faf 100644 --- a/tests/unit/test_auto_highlights.py +++ b/tests/unit/test_auto_highlights.py @@ -1,12 +1,8 @@ -import json -from typing import Any, Dict, Tuple - import factory -import httpx from pytest_httpx import HTTPXMock +import tests.unit.unit_test_utils as unit_test_utils import assemblyai as aai -from assemblyai.api import ENDPOINT_TRANSCRIPT from tests.unit import factories aai.settings.api_key = "test" @@ -36,63 +32,12 @@ class AutohighlightTranscriptResponseFactory( auto_highlights_result = factory.SubFactory(AutohighlightResponseFactory) -def __submit_mock_request( - httpx_mock: HTTPXMock, - mock_response: Dict[str, Any], - config: aai.TranscriptionConfig, -) -> Tuple[Dict[str, Any], aai.Transcript]: - """ - Helper function to abstract mock transcriber calls with given `TranscriptionConfig`, - and perform some common assertions. - """ - - mock_transcript_id = mock_response.get("id", "mock_id") - - # Mock initial submission response (transcript is processing) - mock_processing_response = factories.generate_dict_factory( - factories.TranscriptProcessingResponseFactory - )() - - httpx_mock.add_response( - url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}", - status_code=httpx.codes.OK, - method="POST", - json={ - **mock_processing_response, - "id": mock_transcript_id, # inject ID from main mock response - }, - ) - - # Mock polling-for-completeness response, with completed transcript - httpx_mock.add_response( - url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{mock_transcript_id}", - status_code=httpx.codes.OK, - method="GET", - json=mock_response, - ) - - # == Make API request via SDK == - transcript = aai.Transcriber().transcribe( - data="https://example.org/audio.wav", - config=config, - ) - - # Check that submission and polling requests were made - assert len(httpx_mock.get_requests()) == 2 - - # Extract body of initial submission request - request = httpx_mock.get_requests()[0] - request_body = json.loads(request.content.decode()) - - return request_body, transcript - - def test_auto_highlights_disabled_by_default(httpx_mock: HTTPXMock): """ Tests that excluding `auto_highlights` from the `TranscriptionConfig` will result in the default behavior of it being excluded from the request body """ - request_body, transcript = __submit_mock_request( + request_body, transcript = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response=factories.generate_dict_factory( factories.TranscriptCompletedResponseFactory @@ -112,7 +57,7 @@ def test_auto_highlights_enabled(httpx_mock: HTTPXMock): mock_response = factories.generate_dict_factory( AutohighlightTranscriptResponseFactory )() - request_body, transcript = __submit_mock_request( + request_body, transcript = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response=mock_response, config=aai.TranscriptionConfig(auto_highlights=True), diff --git a/tests/unit/test_content_safety.py b/tests/unit/test_content_safety.py index e8f2f61..4b76978 100644 --- a/tests/unit/test_content_safety.py +++ b/tests/unit/test_content_safety.py @@ -1,14 +1,11 @@ -import json import random -from typing import Any, Dict, Tuple import factory -import httpx import pytest from pytest_httpx import HTTPXMock +import tests.unit.unit_test_utils as unit_test_utils import assemblyai as aai -from assemblyai.api import ENDPOINT_TRANSCRIPT from tests.unit import factories aai.settings.api_key = "test" @@ -69,63 +66,12 @@ class ContentSafetyTranscriptResponseFactory( content_safety_labels = factory.SubFactory(ContentSafetyResponseFactory) -def __submit_mock_request( - httpx_mock: HTTPXMock, - mock_response: Dict[str, Any], - config: aai.TranscriptionConfig, -) -> Tuple[Dict[str, Any], aai.Transcript]: - """ - Helper function to abstract mock transcriber calls with given `TranscriptionConfig`, - and perform some common assertions. - """ - - mock_transcript_id = mock_response.get("id", "mock_id") - - # Mock initial submission response (transcript is processing) - mock_processing_response = factories.generate_dict_factory( - factories.TranscriptProcessingResponseFactory - )() - - httpx_mock.add_response( - url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}", - status_code=httpx.codes.OK, - method="POST", - json={ - **mock_processing_response, - "id": mock_transcript_id, # inject ID from main mock response - }, - ) - - # Mock polling-for-completeness response, with completed transcript - httpx_mock.add_response( - url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{mock_transcript_id}", - status_code=httpx.codes.OK, - method="GET", - json=mock_response, - ) - - # == Make API request via SDK == - transcript = aai.Transcriber().transcribe( - data="https://example.org/audio.wav", - config=config, - ) - - # Check that submission and polling requests were made - assert len(httpx_mock.get_requests()) == 2 - - # Extract body of initial submission request - request = httpx_mock.get_requests()[0] - request_body = json.loads(request.content.decode()) - - return request_body, transcript - - def test_content_safety_disabled_by_default(httpx_mock: HTTPXMock): """ Tests that excluding `content_safety` from the `TranscriptionConfig` will result in the default behavior of it being excluded from the request body """ - request_body, transcript = __submit_mock_request( + request_body, transcript = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response=factories.generate_dict_factory( factories.TranscriptCompletedResponseFactory @@ -145,7 +91,7 @@ def test_content_safety_enabled(httpx_mock: HTTPXMock): mock_response = factories.generate_dict_factory( ContentSafetyTranscriptResponseFactory )() - request_body, transcript = __submit_mock_request( + request_body, transcript = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response=mock_response, config=aai.TranscriptionConfig(content_safety=True), @@ -248,7 +194,7 @@ def test_content_safety_with_confidence_threshold(httpx_mock: HTTPXMock): and will be included in the request body """ confidence = 40 - request, _ = __submit_mock_request( + request, _ = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response={}, # Response doesn't matter here; we're just testing the request body config=aai.TranscriptionConfig( @@ -269,7 +215,7 @@ def test_content_safety_with_invalid_confidence_threshold( an exception to be raised before the request is sent """ with pytest.raises(ValueError) as error: - __submit_mock_request( + unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response={}, # We don't expect to produce a response config=aai.TranscriptionConfig( diff --git a/tests/unit/test_entity_detection.py b/tests/unit/test_entity_detection.py index c069a32..73e1a61 100644 --- a/tests/unit/test_entity_detection.py +++ b/tests/unit/test_entity_detection.py @@ -1,12 +1,8 @@ -import json -from typing import Any, Dict, Tuple - import factory -import httpx from pytest_httpx import HTTPXMock +import tests.unit.unit_test_utils as unit_test_utils import assemblyai as aai -from assemblyai.api import ENDPOINT_TRANSCRIPT from tests.unit import factories aai.settings.api_key = "test" @@ -26,63 +22,12 @@ class EntityDetectionResponseFactory(factories.TranscriptCompletedResponseFactor entities = factory.List([factory.SubFactory(EntityFactory)]) -def __submit_mock_request( - httpx_mock: HTTPXMock, - mock_response: Dict[str, Any], - config: aai.TranscriptionConfig, -) -> Tuple[Dict[str, Any], aai.Transcript]: - """ - Helper function to abstract mock transcriber calls with given `TranscriptionConfig`, - and perform some common assertions. - """ - - mock_transcript_id = mock_response.get("id", "mock_id") - - # Mock initial submission response (transcript is processing) - mock_processing_response = factories.generate_dict_factory( - factories.TranscriptProcessingResponseFactory - )() - - httpx_mock.add_response( - url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}", - status_code=httpx.codes.OK, - method="POST", - json={ - **mock_processing_response, - "id": mock_transcript_id, # inject ID from main mock response - }, - ) - - # Mock polling-for-completeness response, with completed transcript - httpx_mock.add_response( - url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{mock_transcript_id}", - status_code=httpx.codes.OK, - method="GET", - json=mock_response, - ) - - # == Make API request via SDK == - transcript = aai.Transcriber().transcribe( - data="https://example.org/audio.wav", - config=config, - ) - - # Check that submission and polling requests were made - assert len(httpx_mock.get_requests()) == 2 - - # Extract body of initial submission request - request = httpx_mock.get_requests()[0] - request_body = json.loads(request.content.decode()) - - return request_body, transcript - - def test_entity_detection_disabled_by_default(httpx_mock: HTTPXMock): """ Tests that excluding `entity_detection` from the `TranscriptionConfig` will result in the default behavior of it being excluded from the request body """ - request_body, transcript = __submit_mock_request( + request_body, transcript = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response=factories.generate_dict_factory( factories.TranscriptCompletedResponseFactory @@ -100,7 +45,7 @@ def test_entity_detection_enabled(httpx_mock: HTTPXMock): response is properly parsed into a `Transcript` object """ mock_response = factories.generate_dict_factory(EntityDetectionResponseFactory)() - request_body, transcript = __submit_mock_request( + request_body, transcript = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response=mock_response, config=aai.TranscriptionConfig(entity_detection=True), diff --git a/tests/unit/test_iab_categories.py b/tests/unit/test_iab_categories.py index 3b8ae75..f8d6ae4 100644 --- a/tests/unit/test_iab_categories.py +++ b/tests/unit/test_iab_categories.py @@ -1,12 +1,8 @@ -import json -from typing import Any, Dict, Tuple - import factory -import httpx from pytest_httpx import HTTPXMock +import tests.unit.unit_test_utils as unit_test_utils import assemblyai as aai -from assemblyai.api import ENDPOINT_TRANSCRIPT from tests.unit import factories aai.settings.api_key = "test" @@ -48,64 +44,13 @@ class IABCategoriesResponseFactory(factories.TranscriptCompletedResponseFactory) iab_categories_result = factory.SubFactory(IABResponseFactory) -def __submit_mock_request( - httpx_mock: HTTPXMock, - mock_response: Dict[str, Any], - config: aai.TranscriptionConfig, -) -> Tuple[Dict[str, Any], aai.Transcript]: - """ - Helper function to abstract mock transcriber calls with given `TranscriptionConfig`, - and perform some common assertions. - """ - - mock_transcript_id = mock_response.get("id", "mock_id") - - # Mock initial submission response (transcript is processing) - mock_processing_response = factories.generate_dict_factory( - factories.TranscriptProcessingResponseFactory - )() - - httpx_mock.add_response( - url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}", - status_code=httpx.codes.OK, - method="POST", - json={ - **mock_processing_response, - "id": mock_transcript_id, # inject ID from main mock response - }, - ) - - # Mock polling-for-completeness response, with completed transcript - httpx_mock.add_response( - url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{mock_transcript_id}", - status_code=httpx.codes.OK, - method="GET", - json=mock_response, - ) - - # == Make API request via SDK == - transcript = aai.Transcriber().transcribe( - data="https://example.org/audio.wav", - config=config, - ) - - # Check that submission and polling requests were made - assert len(httpx_mock.get_requests()) == 2 - - # Extract body of initial submission request - request = httpx_mock.get_requests()[0] - request_body = json.loads(request.content.decode()) - - return request_body, transcript - - def test_iab_categories_disabled_by_default(httpx_mock: HTTPXMock): """ Tests that excluding `iab_categories` from the `TranscriptionConfig` will result in the default behavior of it being excluded from the request body """ - request_body, transcript = __submit_mock_request( + request_body, transcript = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response=factories.generate_dict_factory( factories.TranscriptCompletedResponseFactory @@ -125,7 +70,7 @@ def test_iab_categories_enabled(httpx_mock: HTTPXMock): mock_response = factories.generate_dict_factory(IABCategoriesResponseFactory)() - request_body, transcript = __submit_mock_request( + request_body, transcript = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response=mock_response, config=aai.TranscriptionConfig(iab_categories=True), diff --git a/tests/unit/test_redact_pii.py b/tests/unit/test_redact_pii.py index 8bffdd4..c1db80f 100644 --- a/tests/unit/test_redact_pii.py +++ b/tests/unit/test_redact_pii.py @@ -1,11 +1,9 @@ -import json -from typing import Any, Dict, Tuple - import httpx import pytest from pytest_httpx import HTTPXMock from pytest_mock import MockerFixture +import tests.unit.unit_test_utils as unit_test_utils import assemblyai as aai from assemblyai.api import ENDPOINT_TRANSCRIPT from tests.unit import factories @@ -23,63 +21,12 @@ class TranscriptWithPIIRedactionResponseFactory( ] -def __submit_mock_request( - httpx_mock: HTTPXMock, - mock_response: Dict[str, Any], - config: aai.TranscriptionConfig, -) -> Tuple[Dict[str, Any], aai.Transcript]: - """ - Helper function to abstract mock transcriber calls with given `TranscriptionConfig`, - and perform some common assertions. - """ - - mock_transcript_id = mock_response.get("id", "mock_id") - - # Mock initial submission response (transcript is processing) - mock_processing_response = factories.generate_dict_factory( - factories.TranscriptProcessingResponseFactory - )() - - httpx_mock.add_response( - url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}", - status_code=httpx.codes.OK, - method="POST", - json={ - **mock_processing_response, - "id": mock_transcript_id, # inject ID from main mock response - }, - ) - - # Mock polling-for-completeness response, with completed transcript - httpx_mock.add_response( - url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{mock_transcript_id}", - status_code=httpx.codes.OK, - method="GET", - json=mock_response, - ) - - # == Make API request via SDK == - transcript = aai.Transcriber().transcribe( - data="https://example.org/audio.wav", - config=config, - ) - - # Check that submission and polling requests were made - assert len(httpx_mock.get_requests()) == 2 - - # Extract body of initial submission request - request = httpx_mock.get_requests()[0] - request_body = json.loads(request.content.decode()) - - return request_body, transcript - - def test_redact_pii_disabled_by_default(httpx_mock: HTTPXMock): """ Tests that excluding `redact_pii` from the `TranscriptionConfig` will result in the default behavior of it being excluded from the request body """ - request_body, _ = __submit_mock_request( + request_body, _ = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response=factories.generate_dict_factory( factories.TranscriptCompletedResponseFactory @@ -102,7 +49,7 @@ def test_redact_pii_enabled(httpx_mock: HTTPXMock): aai.types.PIIRedactionPolicy.phone_number, ] - request_body, _ = __submit_mock_request( + request_body, _ = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response=factories.generate_dict_factory( TranscriptWithPIIRedactionResponseFactory @@ -129,7 +76,7 @@ def test_redact_pii_enabled_with_optional_params(httpx_mock: HTTPXMock): ] sub_type = aai.types.PIISubstitutionPolicy.entity_name - request_body, _ = __submit_mock_request( + request_body, _ = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response=factories.generate_dict_factory( TranscriptWithPIIRedactionResponseFactory @@ -154,7 +101,7 @@ def test_redact_pii_fails_without_policies(httpx_mock: HTTPXMock): will result in an exception being raised before the API call is made """ with pytest.raises(ValueError) as error: - __submit_mock_request( + unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response={}, config=aai.TranscriptionConfig( @@ -177,7 +124,7 @@ def test_redact_pii_params_excluded_when_disabled(httpx_mock: HTTPXMock): Tests that additional PII redaction parameters are excluded from the submission request body if `redact_pii` itself is not enabled. """ - request_body, _ = __submit_mock_request( + request_body, _ = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response=factories.generate_dict_factory( factories.TranscriptCompletedResponseFactory @@ -244,7 +191,7 @@ def test_get_pii_redacted_audio_url(httpx_mock: HTTPXMock): Tests that the PII-redacted audio URL can be retrieved from the API with a successful response """ - _, transcript = __submit_mock_request( + _, transcript = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response=factories.generate_dict_factory( TranscriptWithPIIRedactionResponseFactory @@ -273,11 +220,11 @@ def test_get_pii_redacted_audio_url_fails_if_redact_pii_not_enabled_for_transcri `redact_pii` was not enabled for the transcript and `get_redacted_audio_url` is called """ - _, transcript = __submit_mock_request( + _, transcript = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response=factories.generate_dict_factory( - factories.TranscriptCompletedResponseFactory # standard response - )(), + factories.TranscriptCompletedResponseFactory + )(), # standard response config=aai.TranscriptionConfig(), # blank config ) @@ -298,7 +245,7 @@ def test_get_pii_redacted_audio_url_fails_if_redact_pii_audio_not_enabled_for_tr `redact_pii_audio` was not enabled for the transcript and `get_redacted_audio_url` is called """ - _, transcript = __submit_mock_request( + _, transcript = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response={ **factories.generate_dict_factory( @@ -326,7 +273,7 @@ def test_get_pii_redacted_audio_url_fails_if_bad_response(httpx_mock: HTTPXMock) the request to fetch the redacted audio URL returns a `400` status code, indicating that the redacted audio has expired """ - _, transcript = __submit_mock_request( + _, transcript = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response=factories.generate_dict_factory( TranscriptWithPIIRedactionResponseFactory @@ -349,7 +296,7 @@ def test_save_pii_redacted_audio(httpx_mock: HTTPXMock, mocker: MockerFixture): to the caller's file system """ - _, transcript = __submit_mock_request( + _, transcript = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response=factories.generate_dict_factory( TranscriptWithPIIRedactionResponseFactory @@ -398,11 +345,11 @@ def test_save_pii_redacted_audio_fails_if_redact_pii_not_enabled_for_transcript( `redact_pii` was not enabled for the transcript and `save_redacted_audio` is called """ - _, transcript = __submit_mock_request( + _, transcript = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response=factories.generate_dict_factory( - factories.TranscriptCompletedResponseFactory # standard response - )(), + factories.TranscriptCompletedResponseFactory + )(), # standard response config=aai.TranscriptionConfig(), # blank config ) @@ -423,7 +370,7 @@ def test_save_pii_redacted_audio_fails_if_redact_pii_audio_not_enabled_for_trans `redact_pii_audio` was not enabled for the transcript and `get_redacted_audio_url` is called """ - _, transcript = __submit_mock_request( + _, transcript = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response={ **factories.generate_dict_factory( @@ -451,7 +398,7 @@ def test_save_pii_redacted_audio_fails_if_bad_response(httpx_mock: HTTPXMock): the request to fetch the redacted audio URL returns a `400` status code, indicating that the redacted audio has expired """ - _, transcript = __submit_mock_request( + _, transcript = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response=factories.generate_dict_factory( TranscriptWithPIIRedactionResponseFactory @@ -473,7 +420,7 @@ def test_save_pii_redacted_audio_fails_if_bad_audio_url_response(httpx_mock: HTT Tests that `save_redacted_audio` raises a `RedactedAudioUnavailableError` if the request to fetch the redacted audio **file** returns a non-200 status code """ - _, transcript = __submit_mock_request( + _, transcript = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response=factories.generate_dict_factory( TranscriptWithPIIRedactionResponseFactory diff --git a/tests/unit/test_sentiment_analysis.py b/tests/unit/test_sentiment_analysis.py index f0a8440..aeeda95 100644 --- a/tests/unit/test_sentiment_analysis.py +++ b/tests/unit/test_sentiment_analysis.py @@ -1,12 +1,8 @@ -import json -from typing import Any, Dict, Tuple - import factory -import httpx from pytest_httpx import HTTPXMock +import tests.unit.unit_test_utils as unit_test_utils import assemblyai as aai -from assemblyai.api import ENDPOINT_TRANSCRIPT from tests.unit import factories aai.settings.api_key = "test" @@ -21,63 +17,12 @@ class SentimentAnalysisResponseFactory(factories.TranscriptCompletedResponseFact sentiment_analysis_results = factory.List([factory.SubFactory(SentimentFactory)]) -def __submit_mock_request( - httpx_mock: HTTPXMock, - mock_response: Dict[str, Any], - config: aai.TranscriptionConfig, -) -> Tuple[Dict[str, Any], aai.Transcript]: - """ - Helper function to abstract mock transcriber calls with given `TranscriptionConfig`, - and perform some common assertions. - """ - - mock_transcript_id = mock_response.get("id", "mock_id") - - # Mock initial submission response (transcript is processing) - mock_processing_response = factories.generate_dict_factory( - factories.TranscriptProcessingResponseFactory - )() - - httpx_mock.add_response( - url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}", - status_code=httpx.codes.OK, - method="POST", - json={ - **mock_processing_response, - "id": mock_transcript_id, # inject ID from main mock response - }, - ) - - # Mock polling-for-completeness response, with completed transcript - httpx_mock.add_response( - url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{mock_transcript_id}", - status_code=httpx.codes.OK, - method="GET", - json=mock_response, - ) - - # == Make API request via SDK == - transcript = aai.Transcriber().transcribe( - data="https://example.org/audio.wav", - config=config, - ) - - # Check that submission and polling requests were made - assert len(httpx_mock.get_requests()) == 2 - - # Extract body of initial submission request - request = httpx_mock.get_requests()[0] - request_body = json.loads(request.content.decode()) - - return request_body, transcript - - def test_sentiment_analysis_disabled_by_default(httpx_mock: HTTPXMock): """ Tests that excluding `sentiment_analysis` from the `TranscriptionConfig` will result in the default behavior of it being excluded from the request body """ - request_body, transcript = __submit_mock_request( + request_body, transcript = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response=factories.generate_dict_factory( factories.TranscriptCompletedResponseFactory @@ -95,7 +40,7 @@ def test_sentiment_analysis_enabled(httpx_mock: HTTPXMock): response is properly parsed into a `Transcript` object """ mock_response = factories.generate_dict_factory(SentimentAnalysisResponseFactory)() - request_body, transcript = __submit_mock_request( + request_body, transcript = unit_test_utils.submit_mock_transcription_request( httpx_mock, mock_response=mock_response, config=aai.TranscriptionConfig(sentiment_analysis=True), diff --git a/tests/unit/test_summarization.py b/tests/unit/test_summarization.py index 15fb2ee..53d5e55 100644 --- a/tests/unit/test_summarization.py +++ b/tests/unit/test_summarization.py @@ -1,63 +1,17 @@ -import json -from typing import Any, Dict - -import httpx +import factory import pytest from pytest_httpx import HTTPXMock +import tests.unit.factories as factories +import tests.unit.unit_test_utils as test_utils import assemblyai as aai -from assemblyai.api import ENDPOINT_TRANSCRIPT from tests.unit import factories aai.settings.api_key = "test" -def __submit_request(httpx_mock: HTTPXMock, **params) -> Dict[str, Any]: - """ - Helper function to abstract calling transcriber with given parameters, - and perform some common assertions. - - Returns the body (dictionary) of the initial submission request. - """ - summary = "example summary" - - mock_transcript_response = factories.generate_dict_factory( - factories.TranscriptCompletedResponseFactory - )() - - # Mock initial submission response - httpx_mock.add_response( - url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}", - status_code=httpx.codes.OK, - method="POST", - json=mock_transcript_response, - ) - - # Mock polling-for-completeness response, with mock summary result - httpx_mock.add_response( - url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{mock_transcript_response['id']}", - status_code=httpx.codes.OK, - method="GET", - json={**mock_transcript_response, "summary": summary}, - ) - - # == Make API request via SDK == - transcript = aai.Transcriber().transcribe( - data="https://example.org/audio.wav", - config=aai.TranscriptionConfig( - **params, - ), - ) - - # Check that submission and polling requests were made - assert len(httpx_mock.get_requests()) == 2 - - # Check that summary field from response was traced back through SDK classes - assert transcript.summary == summary - - # Extract and return body of initial submission request - request = httpx_mock.get_requests()[0] - return json.loads(request.content.decode()) +class SummarizationResponseFactory(factories.TranscriptCompletedResponseFactory): + summary = factory.Faker("sentence") @pytest.mark.parametrize("required_field", ["punctuate", "format_text"]) @@ -69,7 +23,13 @@ def test_summarization_fails_without_required_field( if `summarization` is enabled and the given required field is disabled """ with pytest.raises(ValueError) as error: - __submit_request(httpx_mock, summarization=True, **{required_field: False}) + test_utils.submit_mock_transcription_request( + httpx_mock, + {}, + config=aai.TranscriptionConfig( + summarization=True, **{required_field: False} # type: ignore + ), + ) # Check that the error message informs the user of the invalid parameter assert required_field in str(error) @@ -86,17 +46,41 @@ def test_summarization_disabled_by_default(httpx_mock: HTTPXMock): Tests that excluding `summarization` from the `TranscriptionConfig` will result in the default behavior of it being excluded from the request body """ - request_body = __submit_request(httpx_mock) + mock_response = factories.generate_dict_factory( + factories.TranscriptCompletedResponseFactory + )() + request_body, transcript = test_utils.submit_mock_transcription_request( + httpx_mock, + mock_response, + config=aai.TranscriptionConfig(), + ) + + # Check that request body was properly defined assert request_body.get("summarization") is None + # Check that transcript was properly parsed from JSON response + assert transcript.error is None + assert transcript.summary is None + def test_default_summarization_params(httpx_mock: HTTPXMock): """ Tests that including `summarization=True` in the `TranscriptionConfig` will result in `summarization=True` in the request body. """ - request_body = __submit_request(httpx_mock, summarization=True) + mock_response = factories.generate_dict_factory(SummarizationResponseFactory)() + request_body, transcript = test_utils.submit_mock_transcription_request( + httpx_mock, mock_response, aai.TranscriptionConfig(summarization=True) + ) + + # Check that request body was properly defined assert request_body.get("summarization") == True + assert request_body.get("summary_model") == None + assert request_body.get("summary_type") == None + + # Check that transcript was properly parsed from JSON response + assert transcript.error is None + assert transcript.summary == mock_response["summary"] def test_summarization_with_params(httpx_mock: HTTPXMock): @@ -109,30 +93,51 @@ def test_summarization_with_params(httpx_mock: HTTPXMock): summary_model = aai.SummarizationModel.conversational summary_type = aai.SummarizationType.bullets - request_body = __submit_request( + mock_response = factories.generate_dict_factory(SummarizationResponseFactory)() + + request_body, transcript = test_utils.submit_mock_transcription_request( httpx_mock, - summarization=True, - summary_model=summary_model, - summary_type=summary_type, + mock_response, + aai.TranscriptionConfig( + summarization=True, + summary_model=summary_model, + summary_type=summary_type, + ), ) + # Check that request body was properly defined assert request_body.get("summarization") == True assert request_body.get("summary_model") == summary_model assert request_body.get("summary_type") == summary_type + # Check that transcript was properly parsed from JSON response + assert transcript.error is None + assert transcript.summary == mock_response["summary"] + def test_summarization_params_excluded_when_disabled(httpx_mock: HTTPXMock): """ Tests that additional summarization parameters are excluded from the submission request body if `summarization` itself is not enabled. """ - request_body = __submit_request( + mock_response = factories.generate_dict_factory( + factories.TranscriptCompletedResponseFactory + )() + request_body, transcript = test_utils.submit_mock_transcription_request( httpx_mock, - summarization=False, - summary_model=aai.SummarizationModel.conversational, - summary_type=aai.SummarizationType.bullets, + mock_response, + aai.TranscriptionConfig( + summarization=False, + summary_model=aai.SummarizationModel.conversational, + summary_type=aai.SummarizationType.bullets, + ), ) + # Check that request body was properly defined assert request_body.get("summarization") is None assert request_body.get("summary_model") is None assert request_body.get("summary_type") is None + + # Check that transcript was properly parsed from JSON response + assert transcript.error is None + assert transcript.summary is None diff --git a/tests/unit/unit_test_utils.py b/tests/unit/unit_test_utils.py new file mode 100644 index 0000000..968b52b --- /dev/null +++ b/tests/unit/unit_test_utils.py @@ -0,0 +1,69 @@ +import json +from typing import Any, Dict, Tuple + +import httpx +from pytest_httpx import HTTPXMock + +import assemblyai as aai +from assemblyai.api import ENDPOINT_TRANSCRIPT +from tests.unit import factories + + +def submit_mock_transcription_request( + httpx_mock: HTTPXMock, + mock_response: Dict[str, Any], + config: aai.TranscriptionConfig, +) -> Tuple[Dict[str, Any], aai.transcriber.Transcript]: + """ + Helper function to abstract calling transcriber with given parameters, + and perform some common assertions. + + Args: + httpx_mock: HTTPXMock instance to use for mocking requests + mock_response: Dict to use as mock response from API + config: The `TranscriptionConfig` to use for transcription + + Returns: + A tuple containing the JSON body of the initial submission request, + and the `Transcript` object parsed from the mock response + """ + + mock_transcript_id = mock_response.get("id", "mock_id") + + # Mock initial submission response (transcript is processing) + mock_processing_response = factories.generate_dict_factory( + factories.TranscriptProcessingResponseFactory + )() + + httpx_mock.add_response( + url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}", + status_code=httpx.codes.OK, + method="POST", + json={ + **mock_processing_response, + "id": mock_transcript_id, # inject ID from main mock response + }, + ) + + # Mock polling-for-completeness response, with completed transcript + httpx_mock.add_response( + url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{mock_transcript_id}", + status_code=httpx.codes.OK, + method="GET", + json=mock_response, + ) + + # == Make API request via SDK == + transcript = aai.Transcriber().transcribe( + data="https://example.org/audio.wav", + config=config, + ) + + # Check that submission and polling requests were made + assert len(httpx_mock.get_requests()) == 2 + + # Extract body of initial submission request + request = httpx_mock.get_requests()[0] + request_body = json.loads(request.content.decode()) + + return request_body, transcript