From 881ec21cddc7669ea4a21dc6057d4759fe85b32e Mon Sep 17 00:00:00 2001 From: Vladislav Yarmak Date: Thu, 19 Sep 2019 22:24:09 +0300 Subject: [PATCH] transparent mode implementation --- rsp/__main__.py | 11 ++- rsp/baselistener.py | 19 +++++ rsp/constants.py | 2 + rsp/sockslistener.py | 3 +- rsp/transparentlistener.py | 169 +++++++++++++++++++++++++++++++++++++ 5 files changed, 201 insertions(+), 3 deletions(-) create mode 100644 rsp/baselistener.py create mode 100644 rsp/transparentlistener.py diff --git a/rsp/__main__.py b/rsp/__main__.py index ab79b16..1676177 100644 --- a/rsp/__main__.py +++ b/rsp/__main__.py @@ -11,7 +11,6 @@ from sdnotify import SystemdNotifier import asyncssh -from .sockslistener import SocksListener from .constants import LogLevel from . import utils from .ssh_pool import SSHPool @@ -50,6 +49,9 @@ def parse_args(): default=1080, type=utils.check_port, help="bind port") + listen_group.add_argument("-T", "--transparent", + action="store_true", + help="transparent mode") pool_group = parser.add_argument_group('pool options') pool_group.add_argument("-n", "--pool-size", @@ -141,7 +143,11 @@ async def amain(args, loop): # pragma: no cover "%d steady connections. It will take at least %.2f " "seconds to reach it's full size.", args.pool_size, args.pool_size * 1. / args.connect_rate) - server = SocksListener(listen_address=args.bind_address, + if args.transparent: + from .transparentlistener import TransparentListener as Listener + else: + from .sockslistener import SocksListener as Listener + server = Listener(listen_address=args.bind_address, listen_port=args.bind_port, timeout=args.timeout, pool=pool, @@ -168,6 +174,7 @@ def main(): # pragma: no cover with utils.AsyncLoggingHandler(args.logfile) as log_handler: logger = utils.setup_logger('MAIN', args.verbosity, log_handler) utils.setup_logger('SocksListener', args.verbosity, log_handler) + utils.setup_logger('TransparentListener', args.verbosity, log_handler) utils.setup_logger('SSHPool', args.verbosity, log_handler) logger.info("Starting eventloop...") diff --git a/rsp/baselistener.py b/rsp/baselistener.py new file mode 100644 index 0000000..dd7cf3d --- /dev/null +++ b/rsp/baselistener.py @@ -0,0 +1,19 @@ +from abc import ABC, abstractmethod + +class BaseListener(ABC): + @abstractmethod + async def start(self): + """ Abstract method """ + + @abstractmethod + async def stop(self): + """ Abstract method """ + + @abstractmethod + async def __aenter__(self): + """ Abstract method """ + + @abstractmethod + async def __aexit__(self, exc_type, exc, tb): + """ Abstract method """ + diff --git a/rsp/constants.py b/rsp/constants.py index 746a566..109a2a4 100644 --- a/rsp/constants.py +++ b/rsp/constants.py @@ -15,3 +15,5 @@ def __str__(self): BUFSIZE = 16 * 1024 +SO_ORIGINAL_DST = 80 +SOL_IPV6 = 41 diff --git a/rsp/sockslistener.py b/rsp/sockslistener.py index b7e5a9d..3b7415c 100644 --- a/rsp/sockslistener.py +++ b/rsp/sockslistener.py @@ -7,6 +7,7 @@ from .constants import BUFSIZE from .utils import detect_af +from .baselistener import BaseListener class SocksException(Exception): @@ -28,7 +29,7 @@ class BadAddress(SocksException): SOCKS5REQ = struct.Struct('!BBBB') -class SocksListener: # pylint: disable=too-many-instance-attributes +class SocksListener(BaseListener): # pylint: disable=too-many-instance-attributes def __init__(self, *, listen_address, listen_port, diff --git a/rsp/transparentlistener.py b/rsp/transparentlistener.py new file mode 100644 index 0000000..c78bb5c --- /dev/null +++ b/rsp/transparentlistener.py @@ -0,0 +1,169 @@ +import asyncio +import logging +import collections +import socket +import ctypes +from functools import partial + +from . import constants +from .utils import detect_af +from .baselistener import BaseListener + + +BUFSIZE = constants.BUFSIZE + + +def detect_af(addr): + return socket.getaddrinfo(addr, + None, + socket.AF_UNSPEC, + 0, + 0, + socket.AI_NUMERICHOST)[0][0] + + +class sockaddr(ctypes.Structure): + _fields_ = [('sa_family', ctypes.c_uint16), + ('sa_data', ctypes.c_char * 14), + ] + + +class sockaddr_in(ctypes.Structure): + _fields_ = [('sin_family', ctypes.c_uint16), + ('sin_port', ctypes.c_uint16), + ('sin_addr', ctypes.c_uint32), + ] + + +sockaddr_size = max(ctypes.sizeof(sockaddr_in), ctypes.sizeof(sockaddr)) + + +class sockaddr_in6(ctypes.Structure): + _fields_ = [('sin6_family', ctypes.c_uint16), + ('sin6_port', ctypes.c_uint16), + ('sin6_flowinfo', ctypes.c_uint32), + ('sin6_addr', ctypes.c_char * 16), + ('sin6_scope_id', ctypes.c_uint32), + ] + + +sockaddr6_size = ctypes.sizeof(sockaddr_in6) + + +def get_orig_dst(sock): + own_addr = sock.getsockname()[0] + own_af = detect_af(own_addr) + if own_af == socket.AF_INET: + buf = sock.getsockopt(socket.SOL_IP, constants.SO_ORIGINAL_DST, sockaddr_size) + sa = sockaddr_in.from_buffer_copy(buf) + addr = socket.ntohl(sa.sin_addr) + addr = str(addr >> 24) + '.' + str((addr >> 16) & 0xFF) + '.' + str((addr >> 8) & 0xFF) + '.' + str(addr & 0xFF) + port = socket.ntohs(sa.sin_port) + return addr, port + elif own_af == socket.AF_INET6: + buf = sock.getsockopt(constants.SOL_IPV6, constants.SO_ORIGINAL_DST, sockaddr6_size) + sa = sockaddr_in6.from_buffer_copy(buf) + addr = socket.inet_ntop(socket.AF_INET6, sa.sin6_addr) + port = socket.ntohs(sa.sin_port) + return addr, port + else: + raise RuntimeError("Unknown address family!") + + +class TransparentListener(BaseListener): # pylint: disable=too-many-instance-attributes + def __init__(self, *, + listen_address, + listen_port, + pool, + timeout=4, + loop=None): + self._loop = loop if loop is not None else asyncio.get_event_loop() + self._logger = logging.getLogger(self.__class__.__name__) + self._listen_address = listen_address + self._listen_port = listen_port + self._children = set() + self._server = None + self._pool = pool + self._timeout = timeout + + async def stop(self): + self._server.close() + await self._server.wait_closed() + while self._children: + children = list(self._children) + self._children.clear() + self._logger.debug("Cancelling %d client handlers...", + len(children)) + for task in children: + task.cancel() + await asyncio.wait(children) + # workaround for TCP server keeps spawning handlers for a while + # after wait_closed() completed + await asyncio.sleep(.5) + + async def _pump(self, writer, reader): + while True: + try: + data = await reader.read(BUFSIZE) + except asyncio.CancelledError: + raise + except ConnectionResetError: + break + if not data: + break + writer.write(data) + + try: + await writer.drain() + except ConnectionResetError: + break + except asyncio.CancelledError: + raise + + async def handler(self, reader, writer): + peer_addr = writer.transport.get_extra_info('peername') + self._logger.info("Client %s connected", str(peer_addr)) + dst_writer = None + try: + # Instead get dst addr from socket options + sock = writer.transport.get_extra_info('socket') + dst_addr, dst_port = get_orig_dst(sock) + self._logger.info("Client %s requested connection to %s:%s", + peer_addr, dst_addr, dst_port) + async with self._pool.borrow() as ssh_conn: + dst_reader, dst_writer = await asyncio.wait_for( + ssh_conn.open_connection(dst_addr, dst_port), + self._timeout) + await asyncio.gather(self._pump(writer, dst_reader), + self._pump(dst_writer, reader)) + except asyncio.CancelledError: # pylint: disable=try-except-raise + raise + except Exception as exc: # pragma: no cover + self._logger.exception("Connection handler stopped with exception:" + " %s", str(exc)) + finally: + self._logger.info("Client %s disconnected", str(peer_addr)) + if dst_writer is not None: + dst_writer.close() + writer.close() + + async def start(self): + def _spawn(reader, writer): + def task_cb(task, fut): + self._children.discard(task) + task = self._loop.create_task(self.handler(reader, writer)) + self._children.add(task) + task.add_done_callback(partial(task_cb, task)) + + self._server = await asyncio.start_server(_spawn, + self._listen_address, + self._listen_port) + self._logger.info("Transparent Proxy server listening on %s:%d", + self._listen_address, self._listen_port) + + async def __aenter__(self): + await self.start() + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.stop()