diff --git a/src/message_db/client.py b/src/message_db/client.py index 8f0256c..ab5782d 100644 --- a/src/message_db/client.py +++ b/src/message_db/client.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List from uuid import uuid4 +from psycopg2 import DatabaseError from psycopg2.extensions import connection from psycopg2.extras import Json, RealDictCursor @@ -55,7 +56,7 @@ def _write( data: Dict[str, Any], metadata: Dict[str, Any] | None = None, expected_version: int | None = None, - ) -> int | None: + ) -> int: try: with connection.cursor(cursor_factory=RealDictCursor) as cursor: cursor.execute( @@ -74,7 +75,9 @@ def _write( ) result = cursor.fetchone() - except Exception as exc: + if result is None: + raise ValueError("No result returned from the database operation.") + except DatabaseError as exc: raise ValueError( f"{getattr(exc, 'pgcode')}-{getattr(exc, 'pgerror').splitlines()[0]}" ) from exc @@ -88,7 +91,7 @@ def write( data: Dict, metadata: Dict | None = None, expected_version: int | None = None, - ) -> int | None: + ) -> int: conn = self.connection_pool.get_connection() try: @@ -103,7 +106,7 @@ def write( def write_batch( self, stream_name, data, expected_version: int | None = None - ) -> int | None: + ) -> int: conn = self.connection_pool.get_connection() try: diff --git a/tests/test_write.py b/tests/test_write.py index 16f77c3..d35820d 100644 --- a/tests/test_write.py +++ b/tests/test_write.py @@ -1,4 +1,7 @@ +from unittest.mock import MagicMock + import pytest +from psycopg2.extras import RealDictCursor class TestMessageWrite: @@ -83,3 +86,23 @@ def write_msg(): messages = client.read("concurrentStream-123") assert len(messages) == 10 + + def test_write_with_no_result(self, client): + # Create a mock connection and a mock cursor + mock_connection = MagicMock() + mock_cursor = MagicMock(spec=RealDictCursor) + mock_cursor.fetchone.return_value = None # Simulate no result returned + + # Setup the mock cursor to be used when calling cursor() on the mock connection + mock_connection.cursor.return_value.__enter__.return_value = mock_cursor + + # Execute the _write method and expect a ValueError + with pytest.raises(ValueError) as exc_info: + client._write( + connection=mock_connection, + stream_name="testStream", + message_type="testType", + data={"key": "value"}, + ) + + assert "No result returned from the database operation." in str(exc_info.value)