diff --git a/README.md b/README.md index de894e9..138aca6 100644 --- a/README.md +++ b/README.md @@ -333,7 +333,8 @@ Async OK. Parameters: * `base_path` - Where to store the files on the local filesystem - * `auto_make_dir` - Automatically create the directory as needed. + * `auto_make_dir` (default: `False`) - Automatically create the directory as needed. + * `allow_sync_methods` (default: `True`) - When `False`, all synchronous calls throw a `RuntimeError`. Might be helpful in preventing accidentally using the sync `save`/`exists`/`delete` methods, which would block the async loop too. #### S3Handler @@ -355,6 +356,7 @@ Parameters: * `host_url` - When using [non-AWS S3 service](https://www.google.com/search?q=s3+compatible+storage) (like [Linode](https://www.linode.com/products/object-storage/)), use this url to connect. (Example: `'https://us-east-1.linodeobjects.com'`) * `region_name` - Overrides any region_name defined in the [AWS configuration file](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file) or the `AWS_DEFAULT_REGION` environment variable. Required if using AWS S3 and the value is not already set elsewhere. * `addressing_style` - Overrides any S3.addressing_style set in the [AWS configuration file](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file). + * `allow_sync_methods` (default: `True`) - When `False`, all synchronous calls throw a `RuntimeError`. Might be helpful in preventing accidentally using the sync `save`/`exists`/`delete` methods, which would block the async loop too. Permissions can be configured in three different ways. They can be stored in environment variables, then can be stored in a particular AWS file, or they can be passed in directly. 
diff --git a/filestorage/handler_base.py b/filestorage/handler_base.py index 05ae127..ccc97ed 100644 --- a/filestorage/handler_base.py +++ b/filestorage/handler_base.py @@ -192,6 +192,10 @@ def save_data(self, data: bytes, filename: str) -> str: class AsyncStorageHandlerBase(StorageHandlerBase, ABC): """Base class for all asynchronous storage handlers.""" + def __init__(self, allow_sync_methods=True, **kwargs): + self.allow_sync_methods = allow_sync_methods + super().__init__(**kwargs) + def validate(self) -> Optional[Awaitable]: """Validate that the configuration is set up properly and the necessary libraries are available. @@ -213,6 +217,8 @@ async def async_exists(self, filename: str) -> bool: return await self._async_exists(item) def _exists(self, item: FileItem) -> bool: + if not self.allow_sync_methods: + raise RuntimeError('Sync exists method not allowed') return utils.async_to_sync(self._async_exists)(item) @abstractmethod @@ -228,6 +234,8 @@ async def async_delete(self, filename: str) -> None: await self._async_delete(item) def _delete(self, item: FileItem) -> None: + if not self.allow_sync_methods: + raise RuntimeError('Sync delete method not allowed') utils.async_to_sync(self._async_delete)(item) @abstractmethod @@ -238,6 +246,8 @@ async def _async_delete(self, item: FileItem) -> None: pass def _save(self, item: FileItem) -> str: + if not self.allow_sync_methods: + raise RuntimeError('Sync save method not allowed') return utils.async_to_sync(self._async_save)(item) @abstractmethod diff --git a/filestorage/handler_base.pyi b/filestorage/handler_base.pyi index 0ecb5ea..a799f5d 100644 --- a/filestorage/handler_base.pyi +++ b/filestorage/handler_base.pyi @@ -39,6 +39,10 @@ class StorageHandlerBase(ABC, metaclass=abc.ABCMeta): def save_data(self, data: bytes, filename: str) -> str: ... class AsyncStorageHandlerBase(StorageHandlerBase, ABC, metaclass=abc.ABCMeta): + allow_sync_methods: Any = ... 
+ def __init__( + self, allow_sync_methods: bool = ..., **kwargs: Any + ) -> None: ... def validate(self) -> Optional[Awaitable]: ... async def async_exists(self, filename: str) -> bool: ... async def async_delete(self, filename: str) -> None: ... diff --git a/filestorage/handlers/file.py b/filestorage/handlers/file.py index 19bafab..bab1243 100644 --- a/filestorage/handlers/file.py +++ b/filestorage/handlers/file.py @@ -38,6 +38,8 @@ def make_dir(self, item: Optional[FileItem] = None): """Ensures the provided path exists.""" if not item: item = self.get_item('') + else: + item = item.copy(filename='') local_path = self.local_path(item) if local_path in self._created_dirs: @@ -67,13 +69,14 @@ def _delete(self, item: FileItem) -> None: except FileNotFoundError: pass - def _save(self, item: FileItem) -> Optional[str]: - item.sync_seek(0) - + def _save(self, item: FileItem) -> str: if item.data is None: raise RuntimeError('No data for file {item.filename!r}') - filename = self.resolve_filename(item) + if self.auto_make_dir: + self.make_dir(item) + + item = self.resolve_filename(item) with open(self.local_path(item), 'wb') as destination: with item as f: while True: @@ -82,21 +85,23 @@ def _save(self, item: FileItem) -> Optional[str]: break destination.write(chunk) - return filename + return item.filename - def resolve_filename(self, item: FileItem) -> str: + def resolve_filename(self, item: FileItem) -> FileItem: """Ensures a unique name for this file in the folder""" if not self._exists(item): - return item.filename + return item basename, ext = os.path.splitext(item.filename) - counter = 1 - while True: + for counter in range(1, 1000000): filename = f'{basename}-{counter}{ext}' - item.copy(filename=filename) + item = item.copy(filename=filename) if not self._exists(item): - return item.filename - counter += 1 + return item + else: + raise RuntimeError( + f'Cannot get unique name for file {basename}{ext}' + ) def os_wrap(fn: utils.SyncCallable) -> 
utils.AsyncCallable: @@ -106,29 +111,25 @@ def os_wrap(fn: utils.SyncCallable) -> utils.AsyncCallable: return aiofiles.os.wrap(fn) # type: ignore -class AsyncLocalFileHandler(AsyncStorageHandlerBase, LocalFileHandler): - """Class for storing files locally""" +def disabled_method(*args, **kwargs): + raise RuntimeError('method not allowed') - def __init__(self, base_path, auto_make_dir=False, **kwargs): - super().__init__(**kwargs) - self.base_path = base_path - self.auto_make_dir = auto_make_dir - self._created_dirs: Set[str] = set() - def local_path(self, item: FileItem) -> str: - """Returns the local path to the file.""" - return os.path.join(self.base_path, item.fs_path) +class AsyncLocalFileHandler(LocalFileHandler, AsyncStorageHandlerBase): + """Class for storing files locally""" async def async_make_dir(self, item: Optional[FileItem] = None): """Ensures the provided path exists.""" if not item: - item = self.get_item('dummy') + item = self.get_item('') + else: + item = item.copy(filename='') local_path = self.local_path(item) if local_path in self._created_dirs: return - os_wrap(os.makedirs)(local_path, exist_ok=True) # type: ignore + await os_wrap(os.makedirs)(local_path, exist_ok=True) # type: ignore def validate(self) -> None: if aiofiles is None: @@ -136,7 +137,12 @@ def validate(self) -> None: 'The aiofiles library is required for using ' f'{self.__class__.__name__}' ) + + # Ensure the sync methods can operate while validating + temp_setting = self.allow_sync_methods + self.allow_sync_methods = True super().validate() + self.allow_sync_methods = temp_setting async def _async_exists(self, item: FileItem) -> bool: try: @@ -148,14 +154,18 @@ async def _async_exists(self, item: FileItem) -> bool: async def _async_delete(self, item: FileItem) -> None: try: - aiofiles.os.remove(self.local_path(item)) + await aiofiles.os.remove(self.local_path(item)) except FileNotFoundError: pass - async def _async_save(self, item: FileItem) -> Optional[str]: - await 
item.async_seek(0) + async def _async_save(self, item: FileItem) -> str: + if item.data is None: + raise RuntimeError('No data for file {item.filename!r}') + + if self.auto_make_dir: + await self.async_make_dir(item) - filename = await self.async_resolve_filename(item) + item = await self.async_resolve_filename(item) open_context = aiofiles.open(self.local_path(item), 'wb') async with open_context as destination: # type: ignore async with item as f: @@ -165,18 +175,35 @@ async def _async_save(self, item: FileItem) -> Optional[str]: break await destination.write(chunk) - return filename + return item.filename - async def async_resolve_filename(self, item: FileItem) -> str: + async def async_resolve_filename(self, item: FileItem) -> FileItem: """Ensures a unique name for this file in the folder""" if not await self._async_exists(item): - return item.filename + return item basename, ext = os.path.splitext(item.filename) - counter = 1 - while True: + for counter in range(1, 1000000): filename = f'{basename}-{counter}{ext}' - item.copy(filename=filename) + item = item.copy(filename=filename) if not await self._async_exists(item): - return item.filename - counter += 1 + return item + else: + raise RuntimeError( + f'Cannot get unique name for file {basename}{ext}' + ) + + def _save(self, item: FileItem) -> str: + if not self.allow_sync_methods: + raise RuntimeError('Sync save method not allowed') + return super()._save(item) + + def _exists(self, item: FileItem) -> bool: + if not self.allow_sync_methods: + raise RuntimeError('Sync exists method not allowed') + return super()._exists(item) + + def _delete(self, item: FileItem) -> None: + if not self.allow_sync_methods: + raise RuntimeError('Sync delete method not allowed') + super()._delete(item) diff --git a/filestorage/handlers/file.pyi b/filestorage/handlers/file.pyi index 2b6e837..b2e301d 100644 --- a/filestorage/handlers/file.pyi +++ b/filestorage/handlers/file.pyi @@ -20,17 +20,13 @@ class 
LocalFileHandler(StorageHandlerBase): def local_path(self, item: FileItem) -> str: ... def make_dir(self, item: Optional[FileItem] = ...) -> Any: ... def validate(self) -> None: ... - def resolve_filename(self, item: FileItem) -> str: ... + def resolve_filename(self, item: FileItem) -> FileItem: ... def os_wrap(fn: utils.SyncCallable) -> utils.AsyncCallable: ... +def disabled_method(*args: Any, **kwargs: Any) -> None: ... -class AsyncLocalFileHandler(AsyncStorageHandlerBase, LocalFileHandler): - base_path: Any = ... - auto_make_dir: Any = ... - def __init__( - self, base_path: Any, auto_make_dir: bool = ..., **kwargs: Any - ) -> None: ... - def local_path(self, item: FileItem) -> str: ... +class AsyncLocalFileHandler(LocalFileHandler, AsyncStorageHandlerBase): async def async_make_dir(self, item: Optional[FileItem] = ...) -> Any: ... + allow_sync_methods: bool = ... def validate(self) -> None: ... - async def async_resolve_filename(self, item: FileItem) -> str: ... + async def async_resolve_filename(self, item: FileItem) -> FileItem: ... diff --git a/tests/handlers/test_local_file.py b/tests/handlers/test_local_file.py new file mode 100644 index 0000000..7577a1d --- /dev/null +++ b/tests/handlers/test_local_file.py @@ -0,0 +1,282 @@ +import os +import pytest +from tempfile import TemporaryDirectory + +from filestorage import StorageContainer +from filestorage.exceptions import FilestorageConfigError +from filestorage.handlers import LocalFileHandler, AsyncLocalFileHandler + + +@pytest.fixture +def store(): + return StorageContainer() + + +@pytest.fixture +def directory(): + # Make a new directory and provide the path as a string. + # Will remove the directory when complete. + with TemporaryDirectory() as tempdir: + yield tempdir + + +def exists(directory: str, filename: str) -> bool: + """Check if the given file exists in the given directory. + It's synchronous, but it's probably fine for a test. 
+ """ + return os.path.exists(os.path.join(directory, filename)) + + +def get_contents(directory: str, filename: str) -> bytes: + path = os.path.join(directory, filename) + with open(path, 'rb') as f: + return f.read() + + +def test_auto_create_directory(directory): + directory = os.path.join(directory, 'folder', 'subfolder') + handler = LocalFileHandler(base_path=directory, auto_make_dir=True) + + assert not os.path.exists(directory) + handler.validate() + + assert os.path.exists(directory) + + +def test_error_when_no_directory(directory): + directory = os.path.join(directory, 'folder', 'subfolder') + handler = LocalFileHandler(base_path=directory) + + with pytest.raises(FilestorageConfigError) as err: + handler.validate() + + assert directory.rstrip('/').rstrip('\\') in str(err.value) + assert 'does not exist' in str(err.value) + + +def test_local_file_handler_save(directory): + handler = LocalFileHandler(base_path=directory) + + handler.save_data(filename='test.txt', data=b'contents') + + assert exists(directory, 'test.txt') + assert get_contents(directory, 'test.txt') == b'contents' + + +def test_local_file_handler_try_save_subfolder(directory, store): + store.handler = LocalFileHandler(base_path=directory, auto_make_dir=True) + handler = store / 'folder' / 'subfolder' + + handler.save_data(filename='test.txt', data=b'contents') + + directory = os.path.join(directory, 'folder', 'subfolder') + assert exists(directory, 'test.txt') + assert get_contents(directory, 'test.txt') == b'contents' + + +def test_local_file_save_same_filename(directory): + handler = LocalFileHandler(base_path=directory) + + first = handler.save_data(filename='test.txt', data=b'contents 1') + second = handler.save_data(filename='test.txt', data=b'contents 2') + third = handler.save_data(filename='test.txt', data=b'contents 3') + + assert first == 'test.txt' + assert second == 'test-1.txt' + assert third == 'test-2.txt' + + assert exists(directory, first) + assert exists(directory, second) + 
assert exists(directory, third) + + assert get_contents(directory, first) == b'contents 1' + assert get_contents(directory, second) == b'contents 2' + assert get_contents(directory, third) == b'contents 3' + + +def test_local_file_handler_exists(directory): + handler = LocalFileHandler(base_path=directory) + assert not exists(directory, 'test.txt') + + handler.save_data(filename='test.txt', data=b'contents') + assert exists(directory, 'test.txt') + + +def test_local_file_handler_delete(directory): + handler = LocalFileHandler(base_path=directory) + handler.save_data(filename='test.txt', data=b'contents') + assert exists(directory, 'test.txt') + + handler.delete(filename='test.txt') + + assert not exists(directory, 'test.txt') + + +# Async tests # + + +def test_async_auto_create_directory(directory): + directory = os.path.join(directory, 'folder', 'subfolder') + handler = AsyncLocalFileHandler(base_path=directory, auto_make_dir=True) + assert not os.path.exists(directory) + + handler.validate() + + assert os.path.exists(directory) + + +def test_async_error_when_no_directory(directory): + directory = os.path.join(directory, 'folder', 'subfolder') + handler = AsyncLocalFileHandler(base_path=directory) + + with pytest.raises(FilestorageConfigError) as err: + handler.validate() + + assert directory.rstrip('/').rstrip('\\') in str(err.value) + assert 'does not exist' in str(err.value) + + +def test_async_validate_when_no_sync(directory): + directory = os.path.join(directory, 'folder', 'subfolder') + handler = AsyncLocalFileHandler( + base_path=directory, allow_sync_methods=False, auto_make_dir=True + ) + assert not os.path.exists(directory) + + handler.validate() + + assert os.path.exists(directory) + + +@pytest.mark.asyncio +async def test_async_local_file_handler_save(directory): + handler = AsyncLocalFileHandler(base_path=directory) + + await handler.async_save_data(filename='test.txt', data=b'contents') + + assert exists(directory, 'test.txt') + assert 
get_contents(directory, 'test.txt') == b'contents' + + +@pytest.mark.asyncio +async def test_async_local_file_handler_exists(directory): + handler = AsyncLocalFileHandler(base_path=directory) + assert not exists(directory, 'test.txt') + await handler.async_save_data(filename='test.txt', data=b'contents') + + assert exists(directory, 'test.txt') + + +@pytest.mark.asyncio +async def test_async_local_file_handler_delete(directory): + handler = AsyncLocalFileHandler(base_path=directory) + await handler.async_save_data(filename='test.txt', data=b'contents') + assert exists(directory, 'test.txt') + + await handler.async_delete(filename='test.txt') + + assert not exists(directory, 'test.txt') + + +@pytest.mark.asyncio +async def test_async_to_sync_local_file_handler_save(directory): + handler = AsyncLocalFileHandler(base_path=directory) + + handler.save_data(filename='test.txt', data=b'contents') + + assert exists(directory, 'test.txt') + assert get_contents(directory, 'test.txt') == b'contents' + + +@pytest.mark.asyncio +async def test_async_to_sync_local_file_handler_exists(directory): + handler = AsyncLocalFileHandler(base_path=directory) + assert not exists(directory, 'test.txt') + + handler.save_data(filename='test.txt', data=b'contents') + assert exists(directory, 'test.txt') + + +@pytest.mark.asyncio +async def test_async_to_sync_local_file_handler_delete(directory): + handler = AsyncLocalFileHandler(base_path=directory) + handler.save_data(filename='test.txt', data=b'contents') + assert exists(directory, 'test.txt') + + handler.delete(filename='test.txt') + + assert not exists(directory, 'test.txt') + + +@pytest.mark.asyncio +async def test_async_local_file_handler_try_save_subfolder(directory, store): + store.handler = AsyncLocalFileHandler( + base_path=directory, auto_make_dir=True + ) + handler = store / 'folder' / 'subfolder' + + await handler.async_save_data(filename='test.txt', data=b'contents') + + directory = os.path.join(directory, 'folder', 'subfolder') 
+ assert exists(directory, 'test.txt') + assert get_contents(directory, 'test.txt') == b'contents' + + +@pytest.mark.asyncio +async def test_async_local_file_save_same_filename(directory): + handler = AsyncLocalFileHandler(base_path=directory) + + first = await handler.async_save_data( + filename='test.txt', data=b'contents 1' + ) + second = await handler.async_save_data( + filename='test.txt', data=b'contents 2' + ) + third = await handler.async_save_data( + filename='test.txt', data=b'contents 3' + ) + + assert first == 'test.txt' + assert second == 'test-1.txt' + assert third == 'test-2.txt' + + assert exists(directory, first) + assert exists(directory, second) + assert exists(directory, third) + + assert get_contents(directory, first) == b'contents 1' + assert get_contents(directory, second) == b'contents 2' + assert get_contents(directory, third) == b'contents 3' + + +def test_async_only_save(directory): + handler = AsyncLocalFileHandler( + base_path=directory, allow_sync_methods=False + ) + + with pytest.raises(RuntimeError) as err: + handler.save_data(filename='test.txt', data=b'contents') + + assert str(err.value) == 'Sync save method not allowed' + + +def test_async_only_exists(directory): + handler = AsyncLocalFileHandler( + base_path=directory, allow_sync_methods=False + ) + + with pytest.raises(RuntimeError) as err: + handler.exists(filename='test.txt') + + assert str(err.value) == 'Sync exists method not allowed' + + +def test_async_only_delete(directory): + handler = AsyncLocalFileHandler( + base_path=directory, allow_sync_methods=False + ) + + with pytest.raises(RuntimeError) as err: + handler.delete(filename='test.txt') + + assert str(err.value) == 'Sync delete method not allowed' diff --git a/tests/handlers/test_s3.py b/tests/handlers/test_s3.py index c8eb481..9806a4f 100644 --- a/tests/handlers/test_s3.py +++ b/tests/handlers/test_s3.py @@ -32,6 +32,11 @@ def handler(): return S3Handler(bucket_name='bucket') +@pytest.fixture +def 
async_only_handler(): + return S3Handler(bucket_name='bucket', allow_sync_methods=False) + + @pytest.mark.asyncio async def test_validate(mock_s3_resource, handler): await handler.validate() @@ -94,3 +99,33 @@ def test_delete(mock_s3_resource, handler): handler._delete(item) assert mock_s3_resource._file_object._deleted + + +# When allow_sync_methods is False, these should all throw a RuntimeError + + +def test_cant_save(async_only_handler): + item = async_only_handler.get_item('foo.txt', data=BytesIO(b'contents')) + + with pytest.raises(RuntimeError) as err: + async_only_handler._save(item) + + assert str(err.value) == 'Sync save method not allowed' + + +def test_cant_exists(async_only_handler): + item = async_only_handler.get_item('foo.txt', data=BytesIO(b'contents')) + + with pytest.raises(RuntimeError) as err: + async_only_handler._exists(item) + + assert str(err.value) == 'Sync exists method not allowed' + + +def test_cant_delete(async_only_handler): + item = async_only_handler.get_item('foo.txt', data=BytesIO(b'contents')) + + with pytest.raises(RuntimeError) as err: + async_only_handler._delete(item) + + assert str(err.value) == 'Sync delete method not allowed'