Skip to content

Commit

Permalink
Add client and server support for SFTP copy-data extension
Browse files Browse the repository at this point in the history
This commit adds client and server support for the SFTP "copy-data"
extension, and a new remote_copy() method on SFTPClient wihch allows you
to make a request to copy bytes between two files on the remote server
without needing to download and re-upload the data, if the server
supports it.

Thanks go to Ali Khosravi for suggesting this addition.
  • Loading branch information
ronf committed Nov 30, 2024
1 parent 4917d8d commit 743966d
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 19 deletions.
138 changes: 134 additions & 4 deletions asyncssh/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@
MAX_SFTP_WRITE_LEN = 4*1024*1024 # 4 MiB
MAX_SFTP_PACKET_LEN = MAX_SFTP_WRITE_LEN + 1024

_COPY_DATA_BLOCK_SIZE = 256*1024 # 256 KiB

_MAX_SFTP_REQUESTS = 128
_MAX_READDIR_NAMES = 128

Expand Down Expand Up @@ -806,6 +808,24 @@ async def run(self) -> None:
if self._progress_handler and self._total_bytes == 0:
self._progress_handler(self._srcpath, self._dstpath, 0, 0)

if self._srcfs == self._dstfs and \
isinstance(self._srcfs, SFTPClient):
try:
await self._srcfs.remote_copy(
cast(SFTPClientFile, self._src),
cast(SFTPClientFile, self._dst))
except SFTPOpUnsupported:
pass
else:
self._bytes_copied = self._total_bytes

if self._progress_handler:
self._progress_handler(self._srcpath, self._dstpath,
self._bytes_copied,
self._total_bytes)

return

async for _, datalen in self.iter():
if datalen:
self._bytes_copied += datalen
Expand All @@ -822,8 +842,6 @@ async def run(self) -> None:
setattr(exc, 'offset', self._bytes_copied)

raise exc


finally:
if self._src: # pragma: no branch
await self._src.close()
Expand Down Expand Up @@ -2472,6 +2490,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop,
self._supports_fsync = False
self._supports_lsetstat = False
self._supports_limits = False
self._supports_copy_data = False

@property
def version(self) -> int:
Expand Down Expand Up @@ -2692,6 +2711,8 @@ async def start(self) -> None:
self._supports_lsetstat = True
elif name == b'[email protected]' and data == b'1':
self._supports_limits = True
elif name == b'copy-data' and data == b'1':
self._supports_copy_data = True

if version == 3:
# Check if the server has a buggy SYMLINK implementation
Expand Down Expand Up @@ -3090,6 +3111,26 @@ async def fsync(self, handle: bytes) -> None:
else:
raise SFTPOpUnsupported('fsync not supported')

async def copy_data(self, read_from_handle: bytes, read_from_offset: int,
read_from_length: int, write_to_handle: bytes,
write_to_offset: int) -> None:
"""Make an SFTP copy data request"""

if self._supports_copy_data:
self.logger.debug1('Sending copy-data from handle %s, '
'offset %d, length %d to handle %s, '
'offset %d', read_from_handle.hex(),
read_from_offset, read_from_length,
write_to_handle.hex(), write_to_offset)

await self._make_request(b'copy-data', String(read_from_handle),
UInt64(read_from_offset),
UInt64(read_from_length),
String(write_to_handle),
UInt64(write_to_offset))
else:
raise SFTPOpUnsupported('copy-data not supported')

def exit(self) -> None:
"""Handle a request to close the SFTP session"""

Expand Down Expand Up @@ -3142,6 +3183,15 @@ async def __aexit__(self, _exc_type: Optional[Type[BaseException]],
await self.close()
return False

@property
def handle(self) -> bytes:
"""Return handle or raise an error if clsoed"""

if self._handle is None:
raise ValueError('I/O operation on closed file')

return self._handle

async def _end(self) -> int:
"""Return the offset of the end of the file"""

Expand Down Expand Up @@ -4233,6 +4283,35 @@ async def mcopy(self, srcpaths: _SFTPPaths,
block_size, max_requests, progress_handler,
error_handler)

async def remote_copy(self, src: SFTPClientFile, dst: SFTPClientFile,
src_offset: int = 0, src_length: int = 0,
dst_offset: int = 0) -> None:
"""Copy data between remote files
:param src:
The remote file object to read data from
:param dst:
The remote file object to write data to
:param src_offset: (optional)
The offset to begin reading data from
:param src_length: (optional)
The number of bytes to attempt to copy
:param dst_offset: (optional)
The offset to begin writing data to
:type src: :class:`SSHClientFile`
:type dst: :class:`SSHClientFile`
:type src_offset: `int`
:type src_length: `int`
:type dst_offset: `int`
:raises: :exc:`SFTPError` if the server doesn't support this
extension or returns an error
"""

await self._handler.copy_data(src.handle, src_offset, src_length,
dst.handle, dst_offset)

async def glob(self, patterns: _SFTPPaths,
error_handler: SFTPErrorHandler = None) -> \
Sequence[BytesOrStr]:
Expand Down Expand Up @@ -5583,7 +5662,8 @@ class SFTPServerHandler(SFTPHandler):
(b'[email protected]', b'1'),
(b'[email protected]', b'1'),
(b'[email protected]', b'1'),
(b'[email protected]', b'1')]
(b'[email protected]', b'1'),
(b'copy-data', b'1')]

_attrib_extensions: List[bytes] = []

Expand Down Expand Up @@ -6437,6 +6517,55 @@ async def _process_limits(self, packet: SSHPacket) -> SFTPLimits:
return SFTPLimits(MAX_SFTP_PACKET_LEN, MAX_SFTP_READ_LEN,
MAX_SFTP_WRITE_LEN, nfiles)

async def _process_copy_data(self, packet: SSHPacket) -> None:
"""Process an incoming copy data request"""

read_from_handle = packet.get_string()
read_from_offset = packet.get_uint64()
read_from_length = packet.get_uint64()
write_to_handle = packet.get_string()
write_to_offset = packet.get_uint64()
packet.check_end()

self.logger.debug1('Received copy-data from handle %s, '
'offset %d, length %d to handle %s, '
'offset %d', read_from_handle.hex(),
read_from_offset, read_from_length,
write_to_handle.hex(), write_to_offset)

src = self._file_handles.get(read_from_handle)
dst = self._file_handles.get(write_to_handle)

if src and dst:
read_to_end = read_from_length == 0

while read_to_end or read_from_length:
if read_to_end:
size = _COPY_DATA_BLOCK_SIZE
else:
size = min(read_from_length, _COPY_DATA_BLOCK_SIZE)

data = self._server.read(src, read_from_offset, size)

if inspect.isawaitable(data):
data = await cast(Awaitable[bytes], data)

result = self._server.write(dst, write_to_offset, data)

if inspect.isawaitable(result):
await result

if len(data) < size:
break

read_from_offset += size
write_to_offset += size

if not read_to_end:
read_from_length -= size
else:
raise SFTPInvalidHandle('Invalid file handle')

_packet_handlers: Dict[Union[int, bytes], _SFTPPacketHandler] = {
FXP_OPEN: _process_open,
FXP_CLOSE: _process_close,
Expand Down Expand Up @@ -6465,7 +6594,8 @@ async def _process_limits(self, packet: SSHPacket) -> SFTPLimits:
b'[email protected]': _process_openssh_link,
b'[email protected]': _process_fsync,
b'[email protected]': _process_lsetstat,
b'[email protected]': _process_limits
b'[email protected]': _process_limits,
b'copy-data': _process_copy_data
}

async def run(self) -> None:
Expand Down
7 changes: 4 additions & 3 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1077,16 +1077,17 @@ SFTP Support
.. autoattribute:: limits
======================================================================= =

===================== =
=========================== =
File transfer methods
===================== =
=========================== =
.. automethod:: get
.. automethod:: put
.. automethod:: copy
.. automethod:: mget
.. automethod:: mput
.. automethod:: mcopy
===================== =
.. automethod:: remote_copy
=========================== =

============================================================================================================================================================================================================================== =
File access methods
Expand Down
85 changes: 73 additions & 12 deletions tests/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,26 @@ async def test_copy(self, sftp):
finally:
remove('src dst')

def test_copy_non_remote(self):
"""Test copying without using remote_copy function"""

@sftp_test
async def _test_copy_non_remote(self, sftp):
"""Test copying without using remote_copy function"""

for src in ('src', b'src', Path('src')):
with self.subTest(src=type(src)):
try:
self._create_file('src')
await sftp.copy(src, 'dst')
self._check_file('src', 'dst')
finally:
remove('src dst')

with patch('asyncssh.sftp.SFTPServerHandler._extensions', []):
# pylint: disable=no-value-for-parameter
_test_copy_non_remote(self)

@sftp_test
async def test_copy_progress(self, sftp):
"""Test copying a file over SFTP with progress reporting"""
Expand All @@ -769,7 +789,9 @@ def _report_progress(_srcpath, _dstpath, bytes_copied, _total_bytes):
progress_handler=_report_progress)
self._check_file('src', 'dst')

self.assertEqual(len(reports), (size // 8192) + 1)
if method != 'copy':
self.assertEqual(len(reports), (size // 8192) + 1)

self.assertEqual(reports[-1], size)
finally:
remove('src dst')
Expand Down Expand Up @@ -1130,6 +1152,37 @@ def err_handler(exc):
finally:
remove('src1 src2 dst')

@sftp_test
async def test_remote_copy_arguments(self, sftp):
"""Test remote copy arguments"""

try:
self._create_file('src', os.urandom(2*1024*1024))

async with sftp.open('src', 'rb') as src:
async with sftp.open('dst', 'wb') as dst:
await sftp.remote_copy(src, dst, 0, 1024*1024, 0)
await sftp.remote_copy(src, dst, 1024*1024, 0, 1024*1024)

self._check_file('src', 'dst')
finally:
remove('src dst')

@sftp_test
async def test_remote_copy_closed_file(self, sftp):
"""Test remote copy of a closed file"""

try:
self._create_file('file')

async with sftp.open('file', 'rb') as f:
await f.close()

with self.assertRaises(ValueError):
await sftp.remote_copy(f, f)
finally:
remove('file')

@sftp_test
async def test_glob(self, sftp):
"""Test a glob pattern match over SFTP"""
Expand Down Expand Up @@ -3173,6 +3226,9 @@ async def _return_invalid_handle(self, path, pflags, attrs):
with self.assertRaises(SFTPFailure):
await f.fsync()

with self.assertRaises(SFTPFailure):
await sftp.remote_copy(f, f)

with self.assertRaises(SFTPFailure):
await f.close()

Expand Down Expand Up @@ -4300,19 +4356,24 @@ async def start_server(cls):

return await cls.create_server(sftp_factory=_IOErrorSFTPServer)

@sftp_test
async def test_put_error(self, sftp):
"""Test error when putting a file to an SFTP server"""
def test_copy_error(self):
"""Test error when copying a file on an SFTP server"""

for method in ('get', 'put', 'copy'):
with self.subTest(method=method):
try:
self._create_file('src', 8*1024*1024*'\0')
@sftp_test
async def _test_copy_error(self, sftp):
"""Test error when copying a file on an SFTP server"""

with self.assertRaises(SFTPFailure):
await getattr(sftp, method)('src', 'dst')
finally:
remove('src dst')
try:
self._create_file('src', 8*1024*1024*'\0')

with self.assertRaises(SFTPFailure):
await sftp.copy('src', 'dst')
finally:
remove('src dst')

with patch('asyncssh.sftp.SFTPServerHandler._extensions', []):
# pylint: disable=no-value-for-parameter
_test_copy_error(self)

@sftp_test
async def test_read_error(self, sftp):
Expand Down

0 comments on commit 743966d

Please sign in to comment.