From 8fb95239df0b118bb5b241f438be8173b318c12c Mon Sep 17 00:00:00 2001 From: Patrick Loeber <98830383+ploeber@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:24:10 +0100 Subject: [PATCH] feat(python/sdk): Add info about HTTP response status code (#7653) GitOrigin-RevId: c6fc606f037f04da42872e72797fe66ca441b19e --- README.md | 26 +++++ assemblyai/__version__.py | 2 +- assemblyai/api.py | 56 ++++++--- assemblyai/client.py | 21 +++- assemblyai/transcriber.py | 201 +++++++++++++++++++++++++-------- assemblyai/types.py | 4 + tests/unit/test_transcriber.py | 9 +- 7 files changed, 249 insertions(+), 70 deletions(-) diff --git a/README.md b/README.md index da61e85..1a87b2c 100644 --- a/README.md +++ b/README.md @@ -937,6 +937,32 @@ The asynchronous approach allows the application to continue running while the t You can identify those two approaches by the `_async` suffix in the `Transcriber`'s method name (e.g. `transcribe` vs `transcribe_async`). +## Getting the HTTP status code + +There are two ways of accessing the HTTP status code: + +- All custom AssemblyAI Error classes have a `status_code` attribute. +- The latest HTTP response is stored in `aai.Client.get_default().last_response` after every API call. This approach also works when no exception is raised. + +```python +transcriber = aai.Transcriber() + +# Option 1: Catch the error +try: + transcript = transcriber.submit("./example.mp3") +except aai.AssemblyAIError as e: + print(e.status_code) + +# Option 2: Access the latest response through the client +client = aai.Client.get_default() + +try: + transcript = transcriber.submit("./example.mp3") +except: + print(client.last_response) + print(client.last_response.status_code) +``` + ## Polling Intervals By default we poll the `Transcript`'s status each `3s`.
In case you would like to adjust that interval: diff --git a/assemblyai/__version__.py b/assemblyai/__version__.py index 98bb08f..d9f2629 100644 --- a/assemblyai/__version__.py +++ b/assemblyai/__version__.py @@ -1 +1 @@ -__version__ = "0.35.1" +__version__ = "0.36.0" diff --git a/assemblyai/api.py b/assemblyai/api.py index 6c16f1d..b2f666a 100644 --- a/assemblyai/api.py +++ b/assemblyai/api.py @@ -43,7 +43,8 @@ def create_transcript( ) if response.status_code != httpx.codes.OK: raise types.TranscriptError( - f"failed to transcribe url {request.audio_url}: {_get_error_message(response)}" + f"failed to transcribe url {request.audio_url}: {_get_error_message(response)}", + response.status_code, ) return types.TranscriptResponse.parse_obj(response.json()) @@ -60,6 +61,7 @@ def get_transcript( if response.status_code != httpx.codes.OK: raise types.TranscriptError( f"failed to retrieve transcript {transcript_id}: {_get_error_message(response)}", + response.status_code, ) return types.TranscriptResponse.parse_obj(response.json()) @@ -76,6 +78,7 @@ def delete_transcript( if response.status_code != httpx.codes.OK: raise types.TranscriptError( f"failed to delete transcript {transcript_id}: {_get_error_message(response)}", + response.status_code, ) return types.TranscriptResponse.parse_obj(response.json()) @@ -102,7 +105,8 @@ def upload_file( if response.status_code != httpx.codes.OK: raise types.TranscriptError( - f"Failed to upload audio file: {_get_error_message(response)}" + f"Failed to upload audio file: {_get_error_message(response)}", + response.status_code, ) return response.json()["upload_url"] @@ -127,7 +131,8 @@ def export_subtitles_srt( if response.status_code != httpx.codes.OK: raise types.TranscriptError( - f"failed to export SRT for transcript {transcript_id}: {_get_error_message(response)}" + f"failed to export SRT for transcript {transcript_id}: {_get_error_message(response)}", + response.status_code, ) return response.text @@ -152,7 +157,8 @@ def 
export_subtitles_vtt( if response.status_code != httpx.codes.OK: raise types.TranscriptError( - f"failed to export VTT for transcript {transcript_id}: {_get_error_message(response)}" + f"failed to export VTT for transcript {transcript_id}: {_get_error_message(response)}", + response.status_code, ) return response.text @@ -174,7 +180,8 @@ def word_search( if response.status_code != httpx.codes.OK: raise types.TranscriptError( - f"failed to search words in transcript {transcript_id}: {_get_error_message(response)}" + f"failed to search words in transcript {transcript_id}: {_get_error_message(response)}", + response.status_code, ) return types.WordSearchMatchResponse.parse_obj(response.json()) @@ -199,17 +206,20 @@ def get_redacted_audio( if response.status_code == httpx.codes.ACCEPTED: raise types.RedactedAudioIncompleteError( - f"redacted audio for transcript {transcript_id} is not ready yet" + f"redacted audio for transcript {transcript_id} is not ready yet", + response.status_code, ) if response.status_code == httpx.codes.BAD_REQUEST: raise types.RedactedAudioExpiredError( - f"redacted audio for transcript {transcript_id} is no longer available" + f"redacted audio for transcript {transcript_id} is no longer available", + response.status_code, ) if response.status_code != httpx.codes.OK: raise types.TranscriptError( - f"failed to retrieve redacted audio for transcript {transcript_id}: {_get_error_message(response)}" + f"failed to retrieve redacted audio for transcript {transcript_id}: {_get_error_message(response)}", + response.status_code, ) return types.RedactedAudioResponse.parse_obj(response.json()) @@ -225,7 +235,8 @@ def get_sentences( if response.status_code != httpx.codes.OK: raise types.TranscriptError( - f"failed to retrieve sentences for transcript {transcript_id}: {_get_error_message(response)}" + f"failed to retrieve sentences for transcript {transcript_id}: {_get_error_message(response)}", + response.status_code, ) return 
types.SentencesResponse.parse_obj(response.json()) @@ -241,7 +252,8 @@ def get_paragraphs( if response.status_code != httpx.codes.OK: raise types.TranscriptError( - f"failed to retrieve paragraphs for transcript {transcript_id}: {_get_error_message(response)}" + f"failed to retrieve paragraphs for transcript {transcript_id}: {_get_error_message(response)}", + response.status_code, ) return types.ParagraphsResponse.parse_obj(response.json()) @@ -264,7 +276,8 @@ def list_transcripts( if response.status_code != httpx.codes.OK: raise types.AssemblyAIError( - f"failed to retrieve transcripts: {_get_error_message(response)}" + f"failed to retrieve transcripts: {_get_error_message(response)}", + response.status_code, ) return types.ListTranscriptResponse.parse_obj(response.json()) @@ -285,7 +298,8 @@ def lemur_question( if response.status_code != httpx.codes.OK: raise types.LemurError( - f"failed to call Lemur questions: {_get_error_message(response)}" + f"failed to call Lemur questions: {_get_error_message(response)}", + response.status_code, ) return types.LemurQuestionResponse.parse_obj(response.json()) @@ -306,7 +320,8 @@ def lemur_summarize( if response.status_code != httpx.codes.OK: raise types.LemurError( - f"failed to call Lemur summary: {_get_error_message(response)}" + f"failed to call Lemur summary: {_get_error_message(response)}", + response.status_code, ) return types.LemurSummaryResponse.parse_obj(response.json()) @@ -327,7 +342,8 @@ def lemur_action_items( if response.status_code != httpx.codes.OK: raise types.LemurError( - f"failed to call Lemur action items: {_get_error_message(response)}" + f"failed to call Lemur action items: {_get_error_message(response)}", + response.status_code, ) return types.LemurActionItemsResponse.parse_obj(response.json()) @@ -348,7 +364,8 @@ def lemur_task( if response.status_code != httpx.codes.OK: raise types.LemurError( - f"failed to call Lemur task: {_get_error_message(response)}" + f"failed to call Lemur task: 
{_get_error_message(response)}", + response.status_code, ) return types.LemurTaskResponse.parse_obj(response.json()) @@ -366,7 +383,8 @@ def lemur_purge_request_data( if response.status_code != httpx.codes.OK: raise types.LemurError( - f"Failed to purge LeMUR request data for provided request ID: {request.request_id}. Error: {_get_error_message(response)}" + f"Failed to purge LeMUR request data for provided request ID: {request.request_id}. Error: {_get_error_message(response)}", + response.status_code, ) return types.LemurPurgeResponse.parse_obj(response.json()) @@ -387,7 +405,8 @@ def lemur_get_response_data( if response.status_code != httpx.codes.OK: raise types.LemurError( - f"Failed to get LeMUR response data for provided request ID: {request_id}. Error: {_get_error_message(response)}" + f"Failed to get LeMUR response data for provided request ID: {request_id}. Error: {_get_error_message(response)}", + response.status_code, ) json_data = response.json() @@ -411,7 +430,8 @@ def create_temporary_token( if response.status_code != httpx.codes.OK: raise types.AssemblyAIError( - f"Failed to create temporary token: {_get_error_message(response)}" + f"Failed to create temporary token: {_get_error_message(response)}", + response.status_code, ) data = types.RealtimeCreateTemporaryTokenResponse.parse_obj(response.json()) diff --git a/assemblyai/client.py b/assemblyai/client.py index a7d4697..e8c7c09 100644 --- a/assemblyai/client.py +++ b/assemblyai/client.py @@ -3,7 +3,6 @@ from typing import ClassVar, Optional import httpx -from typing_extensions import Self from . 
import types from .__version__ import __version__ @@ -41,14 +40,30 @@ def __init__( headers = {"user-agent": user_agent} if self._settings.api_key: - headers["authorization"] = self.settings.api_key + headers["authorization"] = self._settings.api_key + + self._last_response: Optional[httpx.Response] = None + + def _store_response(response): + self._last_response = response self._http_client = httpx.Client( base_url=self.settings.base_url, headers=headers, timeout=self.settings.http_timeout, + event_hooks={"response": [_store_response]}, ) + @property + def last_response(self) -> Optional[httpx.Response]: + """ + Get the last HTTP response, corresponding to the last request sent from this client. + + Returns: + The last HTTP response. + """ + return self._last_response + @property def settings(self) -> types.Settings: """ @@ -72,7 +87,7 @@ def http_client(self) -> httpx.Client: return self._http_client @classmethod - def get_default(cls, api_key_required: bool = True) -> Self: + def get_default(cls, api_key_required: bool = True): """ Return the default client. diff --git a/assemblyai/transcriber.py b/assemblyai/transcriber.py index c07e84c..9103cf7 100644 --- a/assemblyai/transcriber.py +++ b/assemblyai/transcriber.py @@ -17,6 +17,7 @@ Iterator, List, Optional, + Set, Tuple, Union, ) @@ -47,6 +48,10 @@ def __init__( @property def config(self) -> types.TranscriptionConfig: "Returns the configuration from the internal Transcript object" + if self.transcript is None: + raise ValueError( + "Canot access the configuration. The internal Transcript object is None." + ) return types.TranscriptionConfig( **self.transcript.dict( @@ -74,6 +79,10 @@ def wait_for_completion(self) -> Self: """ polls the given transcript until we have a status other than `processing` or `queued` """ + if not self.transcript_id: + raise ValueError( + "Cannot wait for completion. The internal transcript ID is None." 
+ ) while True: # No try-except - if there is an HTTP error then surface it to user @@ -97,6 +106,11 @@ def export_subtitles_srt( *, chars_per_caption: Optional[int], ) -> str: + if not self.transcript or not self.transcript.id: + raise ValueError( + "Cannot export subtitles. The internal Transcript object is None." + ) + return api.export_subtitles_srt( client=self._client.http_client, transcript_id=self.transcript.id, @@ -108,6 +122,11 @@ def export_subtitles_vtt( *, chars_per_caption: Optional[int], ) -> str: + if not self.transcript or not self.transcript.id: + raise ValueError( + "Cannot export subtitles. The internal Transcript object is None." + ) + return api.export_subtitles_vtt( client=self._client.http_client, transcript_id=self.transcript.id, @@ -119,6 +138,11 @@ def word_search( *, words: List[str], ) -> List[types.WordSearchMatch]: + if not self.transcript or not self.transcript.id: + raise ValueError( + "Cannot perform word search. The internal Transcript object is None." + ) + response = api.word_search( client=self._client.http_client, transcript_id=self.transcript.id, @@ -128,6 +152,11 @@ def word_search( return response.matches def get_sentences(self) -> List[types.Sentence]: + if not self.transcript or not self.transcript.id: + raise ValueError( + "Cannot get sentences. The internal Transcript object is None." + ) + response = api.get_sentences( client=self._client.http_client, transcript_id=self.transcript.id, @@ -136,6 +165,11 @@ def get_sentences(self) -> List[types.Sentence]: return response.sentences def get_paragraphs(self) -> List[types.Paragraph]: + if not self.transcript or not self.transcript.id: + raise ValueError( + "Cannot get paragraphs. The internal Transcript object is None." 
+ ) + response = api.get_paragraphs( client=self._client.http_client, transcript_id=self.transcript.id, @@ -156,6 +190,11 @@ def get_redacted_audio_url(self) -> str: "Redacted audio is only available when `redact_pii` and `redact_pii_audio` are set to `True`." ) + if not self.transcript_id: + raise ValueError( + "Cannot get redacted audio url. The internal transcript ID is None." + ) + while True: try: return api.get_redacted_audio( @@ -175,7 +214,8 @@ def save_redacted_audio(self, filepath: str): with httpx.stream(method="GET", url=self.get_redacted_audio_url()) as response: if response.status_code not in (httpx.codes.OK, httpx.codes.NOT_MODIFIED): raise types.RedactedAudioUnavailableError( - f"Fetching redacted audio failed with status code {response.status_code}" + f"Fetching redacted audio failed with status code {response.status_code}", + response.status_code, ) with open(filepath, "wb") as f: for chunk in response.iter_bytes(): @@ -183,8 +223,10 @@ def save_redacted_audio(self, filepath: str): @classmethod def delete_by_id(cls, transcript_id: str) -> types.Transcript: - client = _client.Client.get_default().http_client - response = api.delete_transcript(client=client, transcript_id=transcript_id) + client = _client.Client.get_default() + response = api.delete_transcript( + client=client.http_client, transcript_id=transcript_id + ) return Transcript.from_response(client=client, response=response) @@ -305,83 +347,112 @@ def config(self) -> types.TranscriptionConfig: @property def json_response(self) -> Optional[dict]: "The full JSON response associated with the transcript." 
+ if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.dict() @property def audio_url(self) -> str: "The corresponding audio url" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.audio_url @property def speech_model(self) -> Optional[str]: "The speech model used for the transcription" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") + return self._impl.transcript.speech_model @property def text(self) -> Optional[str]: "The text transcription of your media file" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.text @property def summary(self) -> Optional[str]: "The summarization of the transcript" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.summary @property def chapters(self) -> Optional[List[types.Chapter]]: "The list of auto-chapters results" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.chapters @property def content_safety(self) -> Optional[types.ContentSafetyResponse]: "The results from the content safety analysis" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.content_safety_labels @property def sentiment_analysis(self) -> Optional[List[types.Sentiment]]: "The list of sentiment analysis results" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.sentiment_analysis_results @property def entities(self) -> Optional[List[types.Entity]]: "The list of entity detection results" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.entities 
@property def iab_categories(self) -> Optional[types.IABResponse]: "The results from the IAB category detection" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.iab_categories_result @property def auto_highlights(self) -> Optional[types.AutohighlightResponse]: "The results from the auto-highlights model" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.auto_highlights_result @property def status(self) -> types.TranscriptStatus: "The current status of the transcript" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.status @property def error(self) -> Optional[str]: "The error message in case the transcription fails" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.error @property def words(self) -> Optional[List[types.Word]]: "The list of words in the transcript" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.words @@ -391,30 +462,40 @@ def utterances(self) -> Optional[List[types.Utterance]]: When `dual_channel` or `speaker_labels` is enabled, a list of utterances in the transcript. 
""" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.utterances @property def confidence(self) -> Optional[float]: "The confidence our model has in the transcribed text, between 0 and 1" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.confidence @property def audio_duration(self) -> Optional[int]: "The duration of the audio in seconds" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.audio_duration @property def webhook_status_code(self) -> Optional[int]: "The status code we received from your server when delivering your webhook" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.webhook_status_code @property def webhook_auth(self) -> Optional[bool]: "Whether the webhook was sent with an HTTP authentication header" + if not self._impl.transcript: + raise ValueError("The internal Transcript object is None.") return self._impl.transcript.webhook_auth @@ -539,7 +620,11 @@ def __init__( @property def transcript_ids(self) -> List[str]: - return [t.id for t in self.transcripts] + if any(t.id is None for t in self.transcripts): + raise ValueError("All transcripts must have a transcript ID.") + return [ + t.id for t in self.transcripts if t.id + ] # include the if check for mypy type checker def add_transcript(self, transcript: Union[Transcript, str]) -> None: if isinstance(transcript, Transcript): @@ -554,17 +639,17 @@ def add_transcript(self, transcript: Union[Transcript, str]) -> None: else: raise TypeError("Unsupported type for `transcript`") - return self - - def wait_for_completion(self, return_failures) -> Union[None, List[str]]: + def wait_for_completion( + self, return_failures + ) -> Union[None, List[types.AssemblyAIError]]: transcripts: List[Transcript] = [] - failures: 
List[str] = [] + failures: List[types.AssemblyAIError] = [] - future_transcripts: Dict[concurrent.futures.Future[Transcript], str] = {} + future_transcripts: Set[concurrent.futures.Future[Transcript]] = set() for transcript in self.transcripts: future = transcript.wait_for_completion_async() - future_transcripts[future] = transcript + future_transcripts.add(future) finished_futures, _ = concurrent.futures.wait(future_transcripts) @@ -572,12 +657,13 @@ def wait_for_completion(self, return_failures) -> Union[None, List[str]]: try: transcripts.append(future.result()) except types.TranscriptError as e: - failures.append(str(e)) + failures.append(e) self.transcripts = transcripts - if return_failures: + if return_failures is True: return failures + return None class TranscriptGroup: @@ -616,13 +702,17 @@ def __iter__(self) -> Iterator[Transcript]: return iter(self.transcripts) @classmethod - def get_by_ids(cls, transcript_ids: List[str]) -> Self: + def get_by_ids( + cls, transcript_ids: List[str] + ) -> Union[Self, Tuple[Self, List[types.AssemblyAIError]]]: return cls(transcript_ids=transcript_ids).wait_for_completion() @classmethod def get_by_ids_async( cls, transcript_ids: List[str] - ) -> concurrent.futures.Future[Self]: + ) -> concurrent.futures.Future[ + Union[Self, Tuple[Self, List[types.AssemblyAIError]]] + ]: return cls(transcript_ids=transcript_ids).wait_for_completion_async() @property @@ -643,6 +733,8 @@ def status(self) -> types.TranscriptStatus: return types.TranscriptStatus.processing elif all(s == types.TranscriptStatus.completed for s in all_status): return types.TranscriptStatus.completed + else: + raise ValueError(f"Unexpected status type: {all_status}") @property def lemur(self) -> lemur.Lemur: @@ -672,7 +764,7 @@ def add_transcript( def wait_for_completion( self, return_failures: Optional[bool] = False, - ) -> Union[Self, Tuple[Self, List[str]]]: + ) -> Union[Self, Tuple[Self, List[types.AssemblyAIError]]]: """ Polls each transcript within the 
`TranscriptGroup`. @@ -682,8 +774,10 @@ def wait_for_completion( Args: return_failures: Whether to return a list of errors for transcripts that failed due to HTTP errors. """ - if return_failures: + if return_failures is True: failures = self._impl.wait_for_completion(return_failures=return_failures) + if failures is None: + raise ValueError("return_failures was set but failures object is None") return self, failures self._impl.wait_for_completion(return_failures=return_failures) @@ -693,9 +787,8 @@ def wait_for_completion( def wait_for_completion_async( self, return_failures: Optional[bool] = False, - ) -> Union[ - concurrent.futures.Future[Self], - concurrent.futures.Future[Tuple[Self, List[str]]], + ) -> concurrent.futures.Future[ + Union[Self, Tuple[Self, List[types.AssemblyAIError]]], ]: return self._executor.submit( self.wait_for_completion, return_failures=return_failures @@ -799,11 +892,11 @@ def transcribe_group( config: Optional[types.TranscriptionConfig], poll: bool, return_failures: Optional[bool] = False, - ) -> Union[TranscriptGroup, Tuple[TranscriptGroup, List[str]]]: + ) -> Union[TranscriptGroup, Tuple[TranscriptGroup, List[types.AssemblyAIError]]]: if config is None: config = self.config - future_transcripts: Dict[concurrent.futures.Future[Transcript], str] = {} + future_transcripts: Set[concurrent.futures.Future[Transcript]] = set() with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor: for d in data: @@ -814,32 +907,38 @@ def transcribe_group( poll=False, ) - future_transcripts[transcript_future] = d + future_transcripts.add(transcript_future) finished_futures, _ = concurrent.futures.wait(future_transcripts) transcript_group = TranscriptGroup( client=self._client, ) - failures = [] + failures: List[types.AssemblyAIError] = [] for future in finished_futures: try: transcript_group.add_transcript(future.result()) except types.TranscriptError as e: - failures.append(f"Error processing {future_transcripts[future]}: {e}") + 
failures.append(e) - if poll and return_failures: - transcript_group, completion_failures = ( - transcript_group.wait_for_completion(return_failures=return_failures) - ) + if poll is True and return_failures is True: + res = transcript_group.wait_for_completion(return_failures=return_failures) + if not isinstance(res, tuple): + raise ValueError( + "return_failures was set but did not receive failures object" + ) + transcript_group, completion_failures = res failures.extend(completion_failures) elif poll: - transcript_group = transcript_group.wait_for_completion( - return_failures=return_failures - ) + res = transcript_group.wait_for_completion(return_failures=return_failures) + if not isinstance(res, TranscriptGroup): + raise ValueError( + "return_failures was not set but did receive failures object" + ) + transcript_group = res - if return_failures: + if return_failures is True: return transcript_group, failures else: return transcript_group @@ -895,7 +994,11 @@ def __init__( ) if not max_workers: - max_workers = max(1, os.cpu_count() - 1) + cpu_count = os.cpu_count() + if not cpu_count: + max_workers = 1 + else: + max_workers = max(1, cpu_count - 1) self._executor = concurrent.futures.ThreadPoolExecutor( max_workers=max_workers, @@ -969,7 +1072,7 @@ def submit_group( data: List[Union[str, BinaryIO]], config: Optional[types.TranscriptionConfig] = None, return_failures: Optional[bool] = False, - ) -> Union[TranscriptGroup, Tuple[TranscriptGroup, List[str]]]: + ) -> Union[TranscriptGroup, Tuple[TranscriptGroup, List[types.AssemblyAIError]]]: """ Submits multiple transcription jobs without waiting for their completion. 
@@ -1032,7 +1135,7 @@ def transcribe_group( data: List[Union[str, BinaryIO]], config: Optional[types.TranscriptionConfig] = None, return_failures: Optional[bool] = False, - ) -> Union[TranscriptGroup, Tuple[TranscriptGroup, List[str]]]: + ) -> Union[TranscriptGroup, Tuple[TranscriptGroup, List[types.AssemblyAIError]]]: """ Transcribes a list of files (as local paths, URLs, or binary objects). @@ -1055,9 +1158,8 @@ def transcribe_group_async( data: List[Union[str, BinaryIO]], config: Optional[types.TranscriptionConfig] = None, return_failures: Optional[bool] = False, - ) -> Union[ - concurrent.futures.Future[TranscriptGroup], - concurrent.futures.Future[Tuple[TranscriptGroup, List[str]]], + ) -> concurrent.futures.Future[ + Union[TranscriptGroup, Tuple[TranscriptGroup, List[types.AssemblyAIError]]] ]: """ Transcribes a list of files (as local paths, URLs, or binary objects) asynchronously. @@ -1247,7 +1349,8 @@ def close(self, terminate: bool = False) -> None: try: self._read_thread.join() self._write_thread.join() - self._websocket.close() + if self._websocket: + self._websocket.close() except Exception: pass @@ -1262,15 +1365,18 @@ def _read(self) -> None: """ while not self._stop_event.is_set(): + if not self._websocket: + raise ValueError("Websocket is None") + try: - message = self._websocket.recv(timeout=1) + recv_message = self._websocket.recv(timeout=1) except TimeoutError: continue except websockets.exceptions.ConnectionClosed as exc: return self._handle_error(exc) try: - message = json.loads(message) + message = json.loads(recv_message) except json.JSONDecodeError as exc: self._on_error( types.RealtimeError( @@ -1295,7 +1401,9 @@ def _write(self) -> None: continue try: - if isinstance(data, dict): + if not self._websocket: + raise ValueError("websocket is None") + elif isinstance(data, dict): self._websocket.send(json.dumps(data)) elif isinstance(data, bytes): self._websocket.send(data) @@ -1333,9 +1441,10 @@ def _handle_message( message["message_type"] == 
types.RealtimeMessageTypes.session_information ): - self._on_extra_session_information( - types.RealtimeSessionInformation(**message) - ) + if self._on_extra_session_information is not None: + self._on_extra_session_information( + types.RealtimeSessionInformation(**message) + ) elif "error" in message: self._on_error(types.RealtimeError(message["error"])) @@ -1358,7 +1467,7 @@ def _handle_error(self, error: websockets.exceptions.ConnectionClosed) -> None: error_message = error.reason if error.code != 1000: - self._on_error(types.RealtimeError(error_message)) + self._on_error(types.RealtimeError(error_message, error.code)) self.close() diff --git a/assemblyai/types.py b/assemblyai/types.py index 93fb761..447cbbe 100644 --- a/assemblyai/types.py +++ b/assemblyai/types.py @@ -27,6 +27,10 @@ class AssemblyAIError(Exception): Base exception for all AssemblyAI errors """ + def __init__(self, message: str, status_code: Optional[int] = None): + super().__init__(message) + self.status_code = status_code + class TranscriptError(AssemblyAIError): """ diff --git a/tests/unit/test_transcriber.py b/tests/unit/test_transcriber.py index b40a7dc..ed26124 100644 --- a/tests/unit/test_transcriber.py +++ b/tests/unit/test_transcriber.py @@ -70,6 +70,7 @@ def test_upload_file_fails(httpx_mock: HTTPXMock): # check wheter the TranscriptError contains the specified error message assert returned_error_message in str(excinfo.value) + assert httpx.codes.INTERNAL_SERVER_ERROR == excinfo.value.status_code def test_submit_url_succeeds(httpx_mock: HTTPXMock): @@ -120,6 +121,7 @@ def test_submit_url_fails(httpx_mock: HTTPXMock): transcriber.submit("https://example.org/audio.wav") assert "something went wrong" in str(excinfo) + assert httpx.codes.INTERNAL_SERVER_ERROR == excinfo.value.status_code # check whether we mocked everything assert len(httpx_mock.get_requests()) == 1 @@ -148,6 +150,7 @@ def test_submit_file_fails_due_api_error(httpx_mock: HTTPXMock): # check wheter the Exception contains 
the specified error message assert "something went wrong" in str(excinfo.value) + assert httpx.codes.INTERNAL_SERVER_ERROR == excinfo.value.status_code # check whether we mocked everything assert len(httpx_mock.get_requests()) == 1 @@ -430,7 +433,8 @@ def test_transcribe_group_urls_fails_during_upload(httpx_mock: HTTPXMock): assert len(failures) == 1 # Check whether the error message corresponds to the raised TranscriptError message - assert f"Error processing {expect_failed_audio_url}" in failures[0] + assert "failed to transcribe url" in str(failures[0]) + assert failures[0].status_code == httpx.codes.INTERNAL_SERVER_ERROR def test_transcribe_group_urls_fails_during_polling(httpx_mock: HTTPXMock): @@ -501,7 +505,8 @@ def test_transcribe_group_urls_fails_during_polling(httpx_mock: HTTPXMock): assert len(failures) == 1 # Check whether the error message is correct - assert "failed to retrieve transcript" in failures[0] + assert "failed to retrieve transcript" in str(failures[0]) + assert failures[0].status_code == httpx.codes.INTERNAL_SERVER_ERROR def test_transcribe_async_url_succeeds(httpx_mock: HTTPXMock):