diff --git a/asyncssh/config.py b/asyncssh/config.py index 18ff3c8..a50888a 100644 --- a/asyncssh/config.py +++ b/asyncssh/config.py @@ -425,7 +425,7 @@ class SSHClientConfig(SSHConfig): """Settings from an OpenSSH client config file""" _conditionals = {'host', 'match'} - _no_split = {'remotecommand'} + _no_split = {'proxycommand', 'remotecommand'} _percent_expand = {'CertificateFile', 'IdentityAgent', 'IdentityFile', 'ProxyCommand', 'RemoteCommand'} @@ -559,7 +559,7 @@ def _set_tokens(self) -> None: ('PKCS11Provider', SSHConfig._set_string), ('PreferredAuthentications', SSHConfig._set_string), ('Port', SSHConfig._set_int), - ('ProxyCommand', SSHConfig._set_string_list), + ('ProxyCommand', SSHConfig._set_string), ('ProxyJump', SSHConfig._set_string), ('PubkeyAuthentication', SSHConfig._set_bool), ('RekeyLimit', SSHConfig._set_rekey_limits), diff --git a/asyncssh/connection.py b/asyncssh/connection.py index 9c05c5e..517fce4 100644 --- a/asyncssh/connection.py +++ b/asyncssh/connection.py @@ -115,7 +115,7 @@ from .misc import TermModesArg, TermSizeArg from .misc import async_context_manager, construct_disc_error from .misc import get_symbol_names, ip_address, map_handler_name -from .misc import parse_byte_count, parse_time_interval +from .misc import parse_byte_count, parse_time_interval, split_args from .packet import Boolean, Byte, NameList, String, UInt32, PacketDecodeError from .packet import SSHPacket, SSHPacketHandler, SSHPacketLogger @@ -231,7 +231,7 @@ async def create_server(self, session_factory: TCPListenerFactory, _GlobalRequestResult = Tuple[int, SSHPacket] _KeyOrCertOptions = Mapping[str, object] _ListenerArg = Union[bool, SSHListener] -_ProxyCommand = Optional[Sequence[str]] +_ProxyCommand = Optional[Union[str, Sequence[str]]] _RequestPTY = Union[bool, str] _TCPServerHandlerFactory = Callable[[str, int], SSHSocketSessionFactory] @@ -7144,11 +7144,13 @@ def prepare(self, config: SSHConfig, # type: ignore self.tunnel = tunnel if tunnel != () else config.get('ProxyJump') self.passphrase = passphrase + if proxy_command == (): + proxy_command = cast(Optional[str], config.get('ProxyCommand')) + if isinstance(proxy_command, str): - proxy_command = shlex.split(proxy_command) + proxy_command = split_args(proxy_command) - self.proxy_command = proxy_command if proxy_command != () else \ - cast(Sequence[str], config.get('ProxyCommand')) + self.proxy_command = proxy_command self.family = cast(int, family if family != () else config.get('AddressFamily', socket.AF_UNSPEC)) @@ -9224,7 +9226,7 @@ async def create_server(server_factory: _ServerFactory, async def get_server_host_key( host = '', port: DefTuple[int] = (), *, tunnel: DefTuple[_TunnelConnector] = (), - proxy_command: DefTuple[str] = (), family: DefTuple[int] = (), + proxy_command: DefTuple[_ProxyCommand] = (), family: DefTuple[int] = (), flags: int = 0, local_addr: DefTuple[HostPort] = (), sock: Optional[socket.socket] = None, client_version: DefTuple[BytesOrStr] = (), @@ -9368,7 +9370,7 @@ def conn_factory() -> SSHClientConnection: async def get_server_auth_methods( host = '', port: DefTuple[int] = (), username: DefTuple[str] = (), *, tunnel: DefTuple[_TunnelConnector] = (), - proxy_command: DefTuple[str] = (), family: DefTuple[int] = (), + proxy_command: DefTuple[_ProxyCommand] = (), family: DefTuple[int] = (), flags: int = 0, local_addr: DefTuple[HostPort] = (), sock: Optional[socket.socket] = None, client_version: DefTuple[BytesOrStr] = (), diff --git a/asyncssh/misc.py b/asyncssh/misc.py index d3765f2..76b8b28 100644 --- a/asyncssh/misc.py +++ b/asyncssh/misc.py @@ -23,6 +23,7 @@ import functools import ipaddress import re +import shlex import socket import sys @@ -269,6 +270,18 @@ def parse_time_interval(value: str) -> float: return _parse_units(value, _time_units, 'time interval') +def split_args(command: str) -> Sequence[str]: + """Split a command string into a list of arguments""" + + lex = shlex.shlex(command, posix=True) + lex.whitespace_split = True + + if sys.platform == 'win32': # pragma: no cover + lex.escape = [] + + return list(lex) + + _ACM = TypeVar('_ACM', bound=AsyncContextManager, covariant=True) class _ACMWrapper(Generic[_ACM]):