Skip to content

Commit

Permalink
transparent mode implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Snawoot committed Sep 19, 2019
1 parent f285354 commit 881ec21
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 3 deletions.
11 changes: 9 additions & 2 deletions rsp/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from sdnotify import SystemdNotifier
import asyncssh

from .sockslistener import SocksListener
from .constants import LogLevel
from . import utils
from .ssh_pool import SSHPool
Expand Down Expand Up @@ -50,6 +49,9 @@ def parse_args():
default=1080,
type=utils.check_port,
help="bind port")
listen_group.add_argument("-T", "--transparent",
action="store_true",
help="transparent mode")

pool_group = parser.add_argument_group('pool options')
pool_group.add_argument("-n", "--pool-size",
Expand Down Expand Up @@ -141,7 +143,11 @@ async def amain(args, loop): # pragma: no cover
"%d steady connections. It will take at least %.2f "
"seconds to reach it's full size.", args.pool_size,
args.pool_size * 1. / args.connect_rate)
server = SocksListener(listen_address=args.bind_address,
if args.transparent:
from .transparentlistener import TransparentListener as Listener
else:
from .sockslistener import SocksListener as Listener
server = Listener(listen_address=args.bind_address,
listen_port=args.bind_port,
timeout=args.timeout,
pool=pool,
Expand All @@ -168,6 +174,7 @@ def main(): # pragma: no cover
with utils.AsyncLoggingHandler(args.logfile) as log_handler:
logger = utils.setup_logger('MAIN', args.verbosity, log_handler)
utils.setup_logger('SocksListener', args.verbosity, log_handler)
utils.setup_logger('TransparentListener', args.verbosity, log_handler)
utils.setup_logger('SSHPool', args.verbosity, log_handler)

logger.info("Starting eventloop...")
Expand Down
19 changes: 19 additions & 0 deletions rsp/baselistener.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from abc import ABC, abstractmethod

class BaseListener(ABC):
@abstractmethod
async def start(self):
""" Abstract method """

@abstractmethod
async def stop(self):
""" Abstract method """

@abstractmethod
async def __aenter__(self):
""" Abstract method """

@abstractmethod
async def __aexit__(self, exc_type, exc, tb):
""" Abstract method """

2 changes: 2 additions & 0 deletions rsp/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@ def __str__(self):


BUFSIZE = 16 * 1024
SO_ORIGINAL_DST = 80
SOL_IPV6 = 41
3 changes: 2 additions & 1 deletion rsp/sockslistener.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from .constants import BUFSIZE
from .utils import detect_af
from .baselistener import BaseListener


class SocksException(Exception):
Expand All @@ -28,7 +29,7 @@ class BadAddress(SocksException):
SOCKS5REQ = struct.Struct('!BBBB')


class SocksListener: # pylint: disable=too-many-instance-attributes
class SocksListener(BaseListener): # pylint: disable=too-many-instance-attributes
def __init__(self, *,
listen_address,
listen_port,
Expand Down
169 changes: 169 additions & 0 deletions rsp/transparentlistener.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import asyncio
import logging
import collections
import socket
import ctypes
from functools import partial

from . import constants
from .utils import detect_af
from .baselistener import BaseListener


BUFSIZE = constants.BUFSIZE


def detect_af(addr):
return socket.getaddrinfo(addr,
None,
socket.AF_UNSPEC,
0,
0,
socket.AI_NUMERICHOST)[0][0]


class sockaddr(ctypes.Structure):
_fields_ = [('sa_family', ctypes.c_uint16),
('sa_data', ctypes.c_char * 14),
]


class sockaddr_in(ctypes.Structure):
_fields_ = [('sin_family', ctypes.c_uint16),
('sin_port', ctypes.c_uint16),
('sin_addr', ctypes.c_uint32),
]


sockaddr_size = max(ctypes.sizeof(sockaddr_in), ctypes.sizeof(sockaddr))


class sockaddr_in6(ctypes.Structure):
_fields_ = [('sin6_family', ctypes.c_uint16),
('sin6_port', ctypes.c_uint16),
('sin6_flowinfo', ctypes.c_uint32),
('sin6_addr', ctypes.c_char * 16),
('sin6_scope_id', ctypes.c_uint32),
]


sockaddr6_size = ctypes.sizeof(sockaddr_in6)


def get_orig_dst(sock):
own_addr = sock.getsockname()[0]
own_af = detect_af(own_addr)
if own_af == socket.AF_INET:
buf = sock.getsockopt(socket.SOL_IP, constants.SO_ORIGINAL_DST, sockaddr_size)
sa = sockaddr_in.from_buffer_copy(buf)
addr = socket.ntohl(sa.sin_addr)
addr = str(addr >> 24) + '.' + str((addr >> 16) & 0xFF) + '.' + str((addr >> 8) & 0xFF) + '.' + str(addr & 0xFF)
port = socket.ntohs(sa.sin_port)
return addr, port
elif own_af == socket.AF_INET6:
buf = sock.getsockopt(constants.SOL_IPV6, constants.SO_ORIGINAL_DST, sockaddr6_size)
sa = sockaddr_in6.from_buffer_copy(buf)
addr = socket.inet_ntop(socket.AF_INET6, sa.sin6_addr)
port = socket.ntohs(sa.sin_port)
return addr, port
else:
raise RuntimeError("Unknown address family!")


class TransparentListener(BaseListener): # pylint: disable=too-many-instance-attributes
def __init__(self, *,
listen_address,
listen_port,
pool,
timeout=4,
loop=None):
self._loop = loop if loop is not None else asyncio.get_event_loop()
self._logger = logging.getLogger(self.__class__.__name__)
self._listen_address = listen_address
self._listen_port = listen_port
self._children = set()
self._server = None
self._pool = pool
self._timeout = timeout

async def stop(self):
self._server.close()
await self._server.wait_closed()
while self._children:
children = list(self._children)
self._children.clear()
self._logger.debug("Cancelling %d client handlers...",
len(children))
for task in children:
task.cancel()
await asyncio.wait(children)
# workaround for TCP server keeps spawning handlers for a while
# after wait_closed() completed
await asyncio.sleep(.5)

async def _pump(self, writer, reader):
while True:
try:
data = await reader.read(BUFSIZE)
except asyncio.CancelledError:
raise
except ConnectionResetError:
break
if not data:
break
writer.write(data)

try:
await writer.drain()
except ConnectionResetError:
break
except asyncio.CancelledError:
raise

async def handler(self, reader, writer):
peer_addr = writer.transport.get_extra_info('peername')
self._logger.info("Client %s connected", str(peer_addr))
dst_writer = None
try:
# Instead get dst addr from socket options
sock = writer.transport.get_extra_info('socket')
dst_addr, dst_port = get_orig_dst(sock)
self._logger.info("Client %s requested connection to %s:%s",
peer_addr, dst_addr, dst_port)
async with self._pool.borrow() as ssh_conn:
dst_reader, dst_writer = await asyncio.wait_for(
ssh_conn.open_connection(dst_addr, dst_port),
self._timeout)
await asyncio.gather(self._pump(writer, dst_reader),
self._pump(dst_writer, reader))
except asyncio.CancelledError: # pylint: disable=try-except-raise
raise
except Exception as exc: # pragma: no cover
self._logger.exception("Connection handler stopped with exception:"
" %s", str(exc))
finally:
self._logger.info("Client %s disconnected", str(peer_addr))
if dst_writer is not None:
dst_writer.close()
writer.close()

async def start(self):
def _spawn(reader, writer):
def task_cb(task, fut):
self._children.discard(task)
task = self._loop.create_task(self.handler(reader, writer))
self._children.add(task)
task.add_done_callback(partial(task_cb, task))

self._server = await asyncio.start_server(_spawn,
self._listen_address,
self._listen_port)
self._logger.info("Transparent Proxy server listening on %s:%d",
self._listen_address, self._listen_port)

async def __aenter__(self):
await self.start()
return self

async def __aexit__(self, exc_type, exc, tb):
await self.stop()

0 comments on commit 881ec21

Please sign in to comment.