Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

proxy: Add UDP support to SOCKS5 dialer #194

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
72 changes: 42 additions & 30 deletions internal/socks/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@ var (
aLongTimeAgo = time.Unix(1, 0)
)

func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) {
host, port, err := splitHostPort(address)
func (d *Dialer) connect(ctx context.Context, c net.Conn, req Request) (conn net.Conn, _ net.Addr, ctxErr error) {
var udpHeader []byte

host, port, err := splitHostPort(req.DstAddress)
if err != nil {
return nil, err
return c, nil, err
}
if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
c.SetDeadline(deadline)
defer c.SetDeadline(noDeadline)
if req.Cmd != CmdUDPAssociate {
defer c.SetDeadline(noDeadline)
}
}
if ctx != context.Background() {
errCh := make(chan error, 1)
Expand All @@ -47,14 +51,15 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
}()
}

conn = c
b := make([]byte, 0, 6+len(host)) // the size here is just an estimate
b = append(b, Version5)
if len(d.AuthMethods) == 0 || d.Authenticate == nil {
b = append(b, 1, byte(AuthMethodNotRequired))
} else {
ams := d.AuthMethods
if len(ams) > 255 {
return nil, errors.New("too many authentication methods")
return c, nil, errors.New("too many authentication methods")
}
b = append(b, byte(len(ams)))
for _, am := range ams {
Expand All @@ -69,11 +74,11 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
return
}
if b[0] != Version5 {
return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
return c, nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
}
am := AuthMethod(b[1])
if am == AuthMethodNoAcceptableMethods {
return nil, errors.New("no acceptable authentication methods")
return c, nil, errors.New("no acceptable authentication methods")
}
if d.Authenticate != nil {
if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil {
Expand All @@ -82,7 +87,7 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
}

b = b[:0]
b = append(b, Version5, byte(d.cmd), 0)
b = append(b, Version5, byte(req.Cmd), 0)
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
b = append(b, AddrTypeIPv4)
Expand All @@ -91,17 +96,23 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
b = append(b, AddrTypeIPv6)
b = append(b, ip6...)
} else {
return nil, errors.New("unknown address type")
return c, nil, errors.New("unknown address type")
}
} else {
if len(host) > 255 {
return nil, errors.New("FQDN too long")
return c, nil, errors.New("FQDN too long")
}
b = append(b, AddrTypeFQDN)
b = append(b, byte(len(host)))
b = append(b, host...)
}
b = append(b, byte(port>>8), byte(port))

if req.Cmd == CmdUDPAssociate {
udpHeader = make([]byte, len(b))
copy(udpHeader[3:], b[3:])
}

if _, ctxErr = c.Write(b); ctxErr != nil {
return
}
Expand All @@ -110,17 +121,18 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
return
}
if b[0] != Version5 {
return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
return c, nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
}
if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded {
return nil, errors.New("unknown error " + cmdErr.String())
return c, nil, errors.New("unknown error " + cmdErr.String())
}
if b[2] != 0 {
return nil, errors.New("non-zero reserved field")
return c, nil, errors.New("non-zero reserved field")
}
l := 2
addrType := b[3]
var a Addr
switch b[3] {
switch addrType {
case AddrTypeIPv4:
l += net.IPv4len
a.IP = make(net.IP, net.IPv4len)
Expand All @@ -129,12 +141,13 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
a.IP = make(net.IP, net.IPv6len)
case AddrTypeFQDN:
if _, err := io.ReadFull(c, b[:1]); err != nil {
return nil, err
return c, nil, err
}
l += int(b[0])
default:
return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
return c, nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
}

if cap(b) < l {
b = make([]byte, l)
} else {
Expand All @@ -149,20 +162,19 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
a.Name = string(b[:len(b)-2])
}
a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1])
return &a, nil
}

func splitHostPort(address string) (string, int, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return "", 0, err
}
portnum, err := strconv.Atoi(port)
if err != nil {
return "", 0, err
}
if 1 > portnum || portnum > 0xffff {
return "", 0, errors.New("port number out of range " + port)
if req.Cmd == CmdUDPAssociate {
var uc net.Conn
if uc, err = d.proxyDial(ctx, req.UDPNetwork, a.String()); err != nil {
return c, &a, err
}
c.SetDeadline(noDeadline)
go func() {
defer uc.Close()
io.Copy(io.Discard, c)
}()
return udpConn{Conn: uc, socksConn: c, header: udpHeader}, &a, nil
}
return host, portnum, nil

return c, &a, nil
}
116 changes: 111 additions & 5 deletions internal/socks/dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package socks_test

import (
"context"
"errors"
"io"
"math/rand"
"net"
Expand All @@ -15,6 +16,7 @@ import (

"golang.org/x/net/internal/socks"
"golang.org/x/net/internal/sockstest"
"golang.org/x/net/nettest"
)

func TestDial(t *testing.T) {
Expand All @@ -33,7 +35,7 @@ func TestDial(t *testing.T) {
Username: "username",
Password: "password",
}).Authenticate
c, err := d.DialContext(context.Background(), ss.TargetAddr().Network(), ss.TargetAddr().String())
c, err := d.DialContext(context.Background(), "tcp", ss.TargetAddrPort().String())
if err != nil {
t.Fatal(err)
}
Expand All @@ -60,7 +62,7 @@ func TestDial(t *testing.T) {
Username: "username",
Password: "password",
}).Authenticate
a, err := d.DialWithConn(context.Background(), c, ss.TargetAddr().Network(), ss.TargetAddr().String())
a, err := d.DialWithConn(context.Background(), c, "tcp", ss.TargetAddrPort().String())
if err != nil {
t.Fatal(err)
}
Expand All @@ -79,7 +81,7 @@ func TestDial(t *testing.T) {
defer cancel()
dialErr := make(chan error)
go func() {
c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
c, err := d.DialContext(ctx, "tcp", ss.TargetAddrPort().String())
if err == nil {
c.Close()
}
Expand All @@ -101,7 +103,7 @@ func TestDial(t *testing.T) {
d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
defer cancel()
c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
c, err := d.DialContext(ctx, "tcp", ss.TargetAddrPort().String())
if err == nil {
c.Close()
}
Expand All @@ -119,14 +121,88 @@ func TestDial(t *testing.T) {
for i := 0; i < 2*len(rogueCmdList); i++ {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
defer cancel()
c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
c, err := d.DialContext(ctx, "tcp", ss.TargetAddrPort().String())
if err == nil {
t.Log(c.(*socks.Conn).BoundAddr())
c.Close()
t.Error("should fail")
}
}
})
t.Run("UDPAssociate", func(t *testing.T) {
ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
if err != nil {
t.Fatal(err)
}
defer ss.Close()
d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
c, err := d.DialContext(context.Background(), "udp", ss.TargetAddrPort().String())
if err != nil {
t.Fatal(err)
}
c.Close()
if network := c.RemoteAddr().Network(); network != "udp" {
t.Errorf("RemoteAddr().Network(): expected \"udp\" got %q", network)
}
expected := "127.0.0.1:5964"
if remoteAddr := c.RemoteAddr().String(); remoteAddr != expected {
t.Errorf("RemoteAddr(): expected %q got %q", expected, remoteAddr)
}
if boundAddr := c.(*socks.Conn).BoundAddr().String(); boundAddr != expected {
t.Errorf("BoundAddr(): expected %q got %q", expected, boundAddr)
}
})
t.Run("UDPAssociateWithReadAndWrite", func(t *testing.T) {
rc, cmdFunc, err := packetListenerCmdFunc()
if err != nil {
t.Fatal(err)
}
defer rc.Close()
ss, err := sockstest.NewServer(sockstest.NoAuthRequired, cmdFunc)
if err != nil {
t.Fatal(err)
}
defer ss.Close()
d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
c, err := d.DialContext(context.Background(), "udp", ss.TargetAddrPort().String())
if err != nil {
t.Fatal(err)
}
defer c.Close()
buf := make([]byte, 32)
expected := "HELLO OUTBOUND"
n, err := c.Write([]byte(expected))
if err != nil {
t.Fatal(err)
}
if len(expected) != n {
t.Errorf("Write(): expected %v bytes got %v", len(expected), n)
}
n, addr, err := rc.ReadFrom(buf)
if err != nil {
t.Fatal(err)
}
data, err := socks.SkipUDPHeader(buf[:n])
if err != nil {
t.Fatal(err)
}
if actual := string(data); expected != actual {
t.Errorf("ReadFrom(): expected %q got %q", expected, actual)
}
udpHeader := []byte{0x00, 0x00, 0x00, 0x01, 0x7f, 0x00, 0x00, 0x01, 0x17, 0x4b}
expected = "HELLO INBOUND"
_, err = rc.WriteTo(append(udpHeader, []byte(expected)...), addr)
if err != nil {
t.Fatal(err)
}
n, err = c.Read(buf)
if err != nil {
t.Fatal(err)
}
if actual := string(buf[:n]); expected != actual {
t.Errorf("Read(): expected %q got %q", expected, actual)
}
})
}

func blackholeCmdFunc(rw io.ReadWriter, b []byte) error {
Expand Down Expand Up @@ -168,3 +244,33 @@ func parseDialError(err error) (perr, nerr error) {
perr = err
return
}

func packetListenerCmdFunc() (net.PacketConn, func(io.ReadWriter, []byte) error, error) {
conn, err := nettest.NewLocalPacketListener("udp")
if err != nil {
return nil, nil, err
}
localAddr := conn.LocalAddr().(*net.UDPAddr)
return conn, func(rw io.ReadWriter, b []byte) error {
req, err := sockstest.ParseCmdRequest(b)
if err != nil {
return err
}
if req.Cmd != socks.CmdUDPAssociate {
return errors.New("unexpected command")
}
b, err = sockstest.MarshalCmdReply(socks.Version5, socks.StatusSucceeded, &socks.Addr{IP: localAddr.IP, Port: localAddr.Port})
if err != nil {
return err
}
n, err := rw.Write(b)
if err != nil {
return err
}
if n != len(b) {
return errors.New("short write")
}
_, err = io.Copy(io.Discard, rw)
return err
}, nil
}