diff --git a/asyncssh/sftp.py b/asyncssh/sftp.py index 4b8f0ca..209bee3 100644 --- a/asyncssh/sftp.py +++ b/asyncssh/sftp.py @@ -3759,7 +3759,8 @@ async def _copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, preserve: bool, recurse: bool, follow_symlinks: bool, block_size: int, max_requests: int, progress_handler: SFTPProgressHandler, - error_handler: SFTPErrorHandler) -> None: + error_handler: SFTPErrorHandler, + remote_only: bool) -> None: """Copy a file, directory, or symbolic link""" try: @@ -3795,7 +3796,8 @@ async def _copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, await self._copy(srcfs, dstfs, srcfile, dstfile, srcname.attrs, preserve, recurse, follow_symlinks, block_size, max_requests, - progress_handler, error_handler) + progress_handler, error_handler, + remote_only) self.logger.info(' Finished copy of directory %s to %s', srcpath, dstpath) @@ -3810,6 +3812,9 @@ async def _copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, else: self.logger.info(' Copying file %s to %s', srcpath, dstpath) + if remote_only and not self.supports_remote_copy: + raise SFTPOpUnsupported('Remote copy not supported') + await _SFTPFileCopier(block_size, max_requests, 0, srcattrs.size or 0, srcfs, dstfs, srcpath, dstpath, progress_handler).run() @@ -3846,7 +3851,8 @@ async def _begin_copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, recurse: bool, follow_symlinks: bool, block_size: int, max_requests: int, progress_handler: SFTPProgressHandler, - error_handler: SFTPErrorHandler) -> None: + error_handler: SFTPErrorHandler, + remote_only: bool = False) -> None: """Begin a new file upload, download, or copy""" if block_size <= 0: @@ -3903,7 +3909,8 @@ async def _begin_copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, await self._copy(srcfs, dstfs, srcfile, dstfile, srcname.attrs, preserve, recurse, follow_symlinks, block_size, - max_requests, progress_handler, error_handler) + max_requests, progress_handler, error_handler, + remote_only) async def get(self, remotepaths: _SFTPPaths, localpath: Optional[_SFTPPath] = None, *, @@ -4222,13 +4229,10 @@ async def copy(self, srcpaths: _SFTPPaths, """ - if remote_only and not self.supports_remote_copy: - raise SFTPOpUnsupported('Remote copy not supported') - await self._begin_copy(self, self, srcpaths, dstpath, 'remote copy', False, preserve, recurse, follow_symlinks, block_size, max_requests, progress_handler, - error_handler) + error_handler, remote_only) async def mget(self, remotepaths: _SFTPPaths, localpath: Optional[_SFTPPath] = None, *, @@ -4295,13 +4299,10 @@ async def mcopy(self, srcpaths: _SFTPPaths, """ - if remote_only and not self.supports_remote_copy: - raise SFTPOpUnsupported('Remote copy not supported') - await self._begin_copy(self, self, srcpaths, dstpath, 'remote mcopy', True, preserve, recurse, follow_symlinks, block_size, max_requests, progress_handler, - error_handler) + error_handler, remote_only) async def remote_copy(self, src: _SFTPClientFileOrPath, dst: _SFTPClientFileOrPath, src_offset: int = 0, diff --git a/tests/test_sftp.py b/tests/test_sftp.py index 59b377d..a853508 100644 --- a/tests/test_sftp.py +++ b/tests/test_sftp.py @@ -777,9 +777,14 @@ async def _test_copy_remote_only(self, sftp): for method in ('copy', 'mcopy'): with self.subTest(method=method): - with self.assertRaises(SFTPOpUnsupported): - await getattr(sftp, method)('src', 'dst', - remote_only=True) + try: + self._create_file('src') + + with self.assertRaises(SFTPOpUnsupported): + await getattr(sftp, method)('src', 'dst', + remote_only=True) + finally: + remove('src') with patch('asyncssh.sftp.SFTPServerHandler._extensions', []): # pylint: disable=no-value-for-parameter