Skip to content

Commit

Permalink
DRIVERS-1357: Add socks5srv.py helper script (#181)
Browse files Browse the repository at this point in the history
  • Loading branch information
addaleax committed Dec 16, 2021
1 parent aa6c63a commit 7c409f1
Showing 1 changed file with 254 additions and 0 deletions.
254 changes: 254 additions & 0 deletions .evergreen/socks5srv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
#!/usr/bin/env python3
import socketserver
import socket
import select
import re
import argparse

# Usage: python3 socks5srv.py --port port [--auth username:password] [--map 'host:port to host:port' ...]

class AddressRemapper:
"""A helper for remapping (host, port) tuples to new (host, port) tuples
This is useful for Socks5 servers used in testing environments,
because the successful use of the Socks5 proxy can be demonstrated
by being able to 'connect' to a redirected port, which would always
fail without the proxy, even on localhost-only environments
"""

def __init__(self, mappings):
self.mappings = [AddressRemapper.parse_single_mapping(string) for string in mappings]
self.add_dns_remappings()

@staticmethod
def parse_single_mapping(string):
"""Parse a single mapping of the for '{host}:{port} to {host}:{port}'"""

# Accept either [ipv6]:port or host:port
host_re = r"(\[(?P<{0}_ipv6>[^[\]]+)\]|(?P<{0}_host>[^\[]+))"
port_re = r"(?P<{0}_port>\d+)"

src_re = host_re.format('src') + ':' + port_re.format('src')
dst_re = host_re.format('dst') + ':' + port_re.format('dst')
full_re = '^' + src_re + ' to ' + dst_re + '$'

match = re.match(full_re, string)
if match is None:
raise Exception("Mapping {} does not match format '{{host}}:{{port}} to {{host}}:{{port}}'".format(string))

src = ((match.group('src_ipv6') or match.group('src_host')).encode('utf8'), int(match.group('src_port')))
dst = ((match.group('dst_ipv6') or match.group('dst_host')).encode('utf8'), int(match.group('dst_port')))
return (src, dst)

def add_dns_remappings(self):
"""Add mappings for the IP addresses corresponding to hostnames
For example, if there is a mapping (localhost, 1000) to (localhost, 2000),
then this also adds (127.0.0.1, 1000) to (localhost, 2000)."""

for src, dst in self.mappings:
host, port = src
try:
addrs = socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM)
except socket.gaierror:
continue

existing_src_entries = [src for src, dst in self.mappings]
for af, socktype, proto, canonname, sa in addrs:
if af == socket.AF_INET and sa not in existing_src_entries:
self.mappings.append((sa, dst))
elif af == socket.AF_INET6 and sa[:2] not in existing_src_entries:
self.mappings.append((sa[:2], dst))

def remap(self, hostport):
"""Re-map a (host, port) tuple to a new (host, port) tuple if that was requested"""

for src, dst in self.mappings:
if hostport == src:
return dst
return hostport

class Socks5Server(socketserver.ThreadingTCPServer):
"""A simple Socks5 proxy server"""

def __init__(self, server_address, RequestHandlerClass, args):
socketserver.ThreadingTCPServer.__init__(self,
server_address,
RequestHandlerClass)
self.args = args
self.address_remapper = AddressRemapper(args.map)

class Socks5Handler(socketserver.BaseRequestHandler):
"""Request handler for Socks5 connections"""

def finish(self):
"""Called after handle(), always just closes the connection"""

self.request.close()

def read_exact(self, n):
"""Read n bytes from a socket
In Socks5, strings are prefixed with a single byte containing
their length. This method reads a bytes string containing n bytes
(where n can be a number or a bytes object containing that
single byte).
If reading from the client ends prematurely, this returns None.
"""

if type(n) is bytes:
if len(n) == 0:
return None
assert len(n) == 1
n = n[0]

buf = bytearray(n)
mv = memoryview(buf)
bytes_read = 0
while bytes_read < n:
try:
chunk_length = self.request.recv_into(mv[bytes_read:])
except OSError as exc:
return None
if chunk_length == 0:
return None

bytes_read += chunk_length
return bytes(buf)

def create_outgoing_tcp_connection(self, dst, port):
"""Create an outgoing TCP connection to dst:port"""

outgoing = None
for res in socket.getaddrinfo(dst, port, socket.AF_UNSPEC, socket.SOCK_STREAM):
af, socktype, proto, canonname, sa = res
try:
outgoing = socket.socket(af, socktype, proto)
except OSError as msg:
continue
try:
outgoing.connect(sa)
except OSError as msg:
outgoing.close()
continue
break
return outgoing

def handle(self):
"""Handle the Socks5 communication with a freshly connected client"""

# This implements the Socks5 protocol as specified in
# https://datatracker.ietf.org/doc/html/rfc1928
# and username/password authentication as specified in
# https://datatracker.ietf.org/doc/html/rfc1929
# If you prefer HTML tables over ASCII tables, Wikipedia
# also currently has a decent description of the protocol in
# https://en.wikipedia.org/wiki/SOCKS#SOCKS5.

# Receive/send errors are intentionally left unhandled. Closing
# the socket is just fine in that case for us.

# Client greeting
if self.request.recv(1) != b'\x05': # Socks5 only
return
n_auth = self.request.recv(1)
client_auth_methods = self.read_exact(n_auth)
if client_auth_methods is None:
return

# choose either no-auth or username/password
required_auth_method = b'\x00' if self.server.args.auth is None else b'\x02'
if required_auth_method not in client_auth_methods:
self.request.sendall(b'\x05\xff')
return

self.request.sendall(b'\x05' + required_auth_method)
if required_auth_method == b'\x02':
auth_version = self.request.recv(1)
if auth_version != b'\x01': # Only username/password auth v1
return
username_len = self.request.recv(1)
username = self.read_exact(username_len)
password_len = self.request.recv(1)
password = self.read_exact(password_len)
if username is None or password is None:
return
if username.decode('utf8') + ':' + password.decode('utf8') != self.server.args.auth:
return
self.request.sendall(b'\x01\x00') # auth success

if self.request.recv(1) != b'\x05': # Socks5 only
return
if self.request.recv(1) != b'\x01': # Outgoing TCP only
return
if self.request.recv(1) != b'\x00': # Reserved, must be 0
return

addrtype = self.request.recv(1)
dst = None
if addrtype == b'\x01': # IPv4
ipv4raw = self.read_exact(4)
if ipv4raw is not None:
dst = '.'.join(['{}'] * 4).format(*ipv4raw)
elif addrtype == b'\x03': # Domain
domain_len = self.request.recv(1)
dst = self.read_exact(domain_len)
elif addrtype == b'\x04': # IPv6
ipv6raw = self.read_exact(16)
if ipv6raw is not None:
dst = ':'.join(['{:0>2x}{:0>2x}'] * 8).format(*ipv6raw)
else:
return

if dst is None:
return

portraw = self.read_exact(2)
port = portraw[0] * 256 + portraw[1]

(dst, port) = self.server.address_remapper.remap((dst, port))

outgoing = self.create_outgoing_tcp_connection(dst, port)
if outgoing is None:
self.request.sendall(b'\x05\x01\x00') # just report a general failure
return
# success response, do not bother actually stating the locally bound
# host/port address and instead always say 127.0.0.1:4096.
# for our use case, the client will not be making meaningful use
# of this anyway
self.request.sendall(b'\x05\x00\x00\x01\x7f\x00\x00\x01\x10\x00')

self.raw_proxy(self.request, outgoing)

def raw_proxy(self, a, b):
"""Proxy data between sockets a and b as-is"""

with a, b:
while True:
try:
(readable, _, _) = select.select([a, b], [], [])
except (select.error, ValueError):
return

if not readable:
continue
for sock in readable:
buf = sock.recv(4096)
if buf == b'':
return
if sock is a:
b.sendall(buf)
else:
a.sendall(buf)

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Start a Socks5 proxy server.')
parser.add_argument('--port', type=int, required=True)
parser.add_argument('--auth', type=str)
parser.add_argument('--map', type=str, action='append', default=[])
args = parser.parse_args()

socketserver.TCPServer.allow_reuse_address = True
with Socks5Server(('localhost', args.port), Socks5Handler, args) as server:
server.serve_forever()

0 comments on commit 7c409f1

Please sign in to comment.