diff --git a/balancer/rls/internal/config.go b/balancer/rls/internal/config.go index 305c09106ba..a3deb8906c9 100644 --- a/balancer/rls/internal/config.go +++ b/balancer/rls/internal/config.go @@ -201,7 +201,7 @@ func (*rlsBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, if lookupService == "" { return nil, fmt.Errorf("rls: empty lookup_service in service config {%+v}", string(c)) } - parsedTarget := grpcutil.ParseTarget(lookupService) + parsedTarget := grpcutil.ParseTarget(lookupService, false) if parsedTarget.Scheme == "" { parsedTarget.Scheme = resolver.GetDefaultScheme() } diff --git a/clientconn.go b/clientconn.go index cbd671a8531..11ed4af607a 100644 --- a/clientconn.go +++ b/clientconn.go @@ -23,7 +23,6 @@ import ( "errors" "fmt" "math" - "net" "reflect" "strings" "sync" @@ -48,6 +47,7 @@ import ( _ "google.golang.org/grpc/balancer/roundrobin" // To register roundrobin. _ "google.golang.org/grpc/internal/resolver/dns" // To register dns resolver. _ "google.golang.org/grpc/internal/resolver/passthrough" // To register passthrough resolver. + _ "google.golang.org/grpc/internal/resolver/unix" // To register unix resolver. ) const ( @@ -191,16 +191,6 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * } cc.mkp = cc.dopts.copts.KeepaliveParams - if cc.dopts.copts.Dialer == nil { - cc.dopts.copts.Dialer = func(ctx context.Context, addr string) (net.Conn, error) { - network, addr := parseDialTarget(addr) - return (&net.Dialer{}).DialContext(ctx, network, addr) - } - if cc.dopts.withProxy { - cc.dopts.copts.Dialer = newProxyDialer(cc.dopts.copts.Dialer) - } - } - if cc.dopts.copts.UserAgent != "" { cc.dopts.copts.UserAgent += " " + grpcUA } else { @@ -244,8 +234,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * } // Determine the resolver to use. - cc.parsedTarget = grpcutil.ParseTarget(cc.target) - unixScheme := strings.HasPrefix(cc.target, "unix:") + cc.parsedTarget = grpcutil.ParseTarget(cc.target, cc.dopts.copts.Dialer != nil) channelz.Infof(logger, cc.channelzID, "parsed scheme: %q", cc.parsedTarget.Scheme) resolverBuilder := cc.getResolver(cc.parsedTarget.Scheme) if resolverBuilder == nil { @@ -268,7 +257,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * cc.authority = creds.Info().ServerName } else if cc.dopts.insecure && cc.dopts.authority != "" { cc.authority = cc.dopts.authority - } else if unixScheme { + } else if strings.HasPrefix(cc.target, "unix:") { cc.authority = "localhost" } else { // Use endpoint from "scheme://authority/endpoint" as the default diff --git a/dialoptions.go b/dialoptions.go index a93fcab8f34..e7f86e6d7c8 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -71,7 +71,6 @@ type dialOptions struct { // we need to be able to configure this in tests. resolveNowBackoff func(int) time.Duration resolvers []resolver.Builder - withProxy bool } // DialOption configures how we set up the connection. @@ -325,7 +324,7 @@ func WithInsecure() DialOption { // later release. func WithNoProxy() DialOption { return newFuncDialOption(func(o *dialOptions) { - o.withProxy = false + o.copts.UseProxy = false }) } @@ -595,9 +594,9 @@ func defaultDialOptions() dialOptions { copts: transport.ConnectOptions{ WriteBufferSize: defaultWriteBufSize, ReadBufferSize: defaultReadBufSize, + UseProxy: true, }, resolveNowBackoff: internalbackoff.DefaultExponential.Backoff, - withProxy: true, } } diff --git a/internal/grpcutil/target.go b/internal/grpcutil/target.go index 80b33cdaf90..3e1b22f5a8c 100644 --- a/internal/grpcutil/target.go +++ b/internal/grpcutil/target.go @@ -37,19 +37,32 @@ func split2(s, sep string) (string, string, bool) { } // ParseTarget splits target into a resolver.Target struct containing scheme, -// authority and endpoint. +// authority and endpoint. skipUnixColonParsing indicates that the parse should +// not parse "unix:[path]" cases. This should be true in cases where a custom +// dialer is present, to prevent a behavior change. // // If target is not a valid scheme://authority/endpoint, it returns {Endpoint: // target}. -func ParseTarget(target string) (ret resolver.Target) { +func ParseTarget(target string, skipUnixColonParsing bool) (ret resolver.Target) { var ok bool ret.Scheme, ret.Endpoint, ok = split2(target, "://") if !ok { + if strings.HasPrefix(target, "unix:") && !skipUnixColonParsing { + // Handle the "unix:[path]" case, because splitting on :// only + // handles the "unix://[/absolute/path]" case. Only handle if the + // dialer is nil, to avoid a behavior change with custom dialers. + return resolver.Target{Scheme: "unix", Endpoint: target[len("unix:"):]} + } return resolver.Target{Endpoint: target} } ret.Authority, ret.Endpoint, ok = split2(ret.Endpoint, "/") if !ok { return resolver.Target{Endpoint: target} } + if ret.Scheme == "unix" { + // Add the "/" back in the unix case, so the unix resolver receives the + // actual endpoint. + ret.Endpoint = "/" + ret.Endpoint + } return ret } diff --git a/internal/grpcutil/target_test.go b/internal/grpcutil/target_test.go index 92bdb63d213..562390bfe38 100644 --- a/internal/grpcutil/target_test.go +++ b/internal/grpcutil/target_test.go @@ -32,17 +32,22 @@ func TestParseTarget(t *testing.T) { {Scheme: "passthrough", Authority: "", Endpoint: "/unix/socket/address"}, } { str := test.Scheme + "://" + test.Authority + "/" + test.Endpoint - got := ParseTarget(str) + got := ParseTarget(str, false) if got != test { - t.Errorf("ParseTarget(%q) = %+v, want %+v", str, got, test) + t.Errorf("ParseTarget(%q, false) = %+v, want %+v", str, got, test) + } + got = ParseTarget(str, true) + if got != test { + t.Errorf("ParseTarget(%q, true) = %+v, want %+v", str, got, test) } } } func TestParseTargetString(t *testing.T) { for _, test := range []struct { - targetStr string - want resolver.Target + targetStr string + want resolver.Target + wantWithDialer resolver.Target }{ {targetStr: "", want: resolver.Target{Scheme: "", Authority: "", Endpoint: ""}}, {targetStr: ":///", want: resolver.Target{Scheme: "", Authority: "", Endpoint: ""}}, @@ -70,10 +75,24 @@ func TestParseTargetString(t *testing.T) { {targetStr: "a:/b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a:/b"}}, {targetStr: "a//b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a//b"}}, {targetStr: "a://b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a://b"}}, + + // Unix cases without custom dialer. + // unix:[local_path] and unix:[/absolute] have different behaviors with + // a custom dialer, to prevent behavior changes with custom dialers. + {targetStr: "unix:domain", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "domain"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:domain"}}, + {targetStr: "unix:/domain", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "/domain"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:/domain"}}, } { - got := ParseTarget(test.targetStr) + got := ParseTarget(test.targetStr, false) if got != test.want { - t.Errorf("ParseTarget(%q) = %+v, want %+v", test.targetStr, got, test.want) + t.Errorf("ParseTarget(%q, false) = %+v, want %+v", test.targetStr, got, test.want) + } + wantWithDialer := test.wantWithDialer + if wantWithDialer == (resolver.Target{}) { + wantWithDialer = test.want + } + got = ParseTarget(test.targetStr, true) + if got != wantWithDialer { + t.Errorf("ParseTarget(%q, true) = %+v, want %+v", test.targetStr, got, wantWithDialer) } } } diff --git a/internal/resolver/unix/unix.go b/internal/resolver/unix/unix.go new file mode 100644 index 00000000000..d046e50613d --- /dev/null +++ b/internal/resolver/unix/unix.go @@ -0,0 +1,49 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package unix implements a resolver for unix targets. +package unix + +import ( + "google.golang.org/grpc/internal/transport/networktype" + "google.golang.org/grpc/resolver" +) + +const scheme = "unix" + +type builder struct{} + +func (*builder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) { + cc.UpdateState(resolver.State{Addresses: []resolver.Address{networktype.Set(resolver.Address{Addr: target.Endpoint}, "unix")}}) + return &nopResolver{}, nil +} + +func (*builder) Scheme() string { + return scheme +} + +type nopResolver struct { +} + +func (*nopResolver) ResolveNow(resolver.ResolveNowOptions) {} + +func (*nopResolver) Close() {} + +func init() { + resolver.Register(&builder{}) +} diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index e73b77a15a8..6a4776bb153 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -33,6 +33,7 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" "google.golang.org/grpc/internal/grpcutil" + "google.golang.org/grpc/internal/transport/networktype" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" @@ -137,11 +138,18 @@ type http2Client struct { connectionID uint64 } -func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr string) (net.Conn, error) { +func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr resolver.Address, useProxy bool, grpcUA string) (net.Conn, error) { if fn != nil { - return fn(ctx, addr) + return fn(ctx, addr.Addr) } - return (&net.Dialer{}).DialContext(ctx, "tcp", addr) + networkType := "tcp" + if n, ok := networktype.Get(addr); ok { + networkType = n + } + if networkType == "tcp" && useProxy { + return proxyDial(ctx, addr.Addr, grpcUA) + } + return (&net.Dialer{}).DialContext(ctx, networkType, addr.Addr) } func isTemporary(err error) bool { @@ -172,7 +180,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts } }() - conn, err := dial(connectCtx, opts.Dialer, addr.Addr) + conn, err := dial(connectCtx, opts.Dialer, addr, opts.UseProxy, opts.UserAgent) if err != nil { if opts.FailOnNonTempDialError { return nil, connectionErrorf(isTemporary(err), err, "transport: error while dialing: %v", err) diff --git a/internal/transport/networktype/networktype.go b/internal/transport/networktype/networktype.go new file mode 100644 index 00000000000..96967428b51 --- /dev/null +++ b/internal/transport/networktype/networktype.go @@ -0,0 +1,46 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package networktype declares the network type to be used in the default +// dailer. Attribute of a resolver.Address. +package networktype + +import ( + "google.golang.org/grpc/resolver" +) + +// keyType is the key to use for storing State in Attributes. +type keyType string + +const key = keyType("grpc.internal.transport.networktype") + +// Set returns a copy of the provided address with attributes containing networkType. +func Set(address resolver.Address, networkType string) resolver.Address { + address.Attributes = address.Attributes.WithValues(key, networkType) + return address +} + +// Get returns the network type in the resolver.Address and true, or "", false +// if not present. +func Get(address resolver.Address) (string, bool) { + v := address.Attributes.Value(key) + if v == nil { + return "", false + } + return v.(string), true +} diff --git a/proxy.go b/internal/transport/proxy.go similarity index 73% rename from proxy.go rename to internal/transport/proxy.go index f8f69bfb70f..a662bf39a6c 100644 --- a/proxy.go +++ b/internal/transport/proxy.go @@ -16,13 +16,12 @@ * */ -package grpc +package transport import ( "bufio" "context" "encoding/base64" - "errors" "fmt" "io" "net" @@ -34,8 +33,6 @@ import ( const proxyAuthHeaderKey = "Proxy-Authorization" var ( - // errDisabled indicates that proxy is disabled for the address. - errDisabled = errors.New("proxy is disabled for the address") // The following variable will be overwritten in the tests. httpProxyFromEnvironment = http.ProxyFromEnvironment ) @@ -51,9 +48,6 @@ func mapAddress(ctx context.Context, address string) (*url.URL, error) { if err != nil { return nil, err } - if url == nil { - return nil, errDisabled - } return url, nil } @@ -76,7 +70,7 @@ func basicAuth(username, password string) string { return base64.StdEncoding.EncodeToString([]byte(auth)) } -func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr string, proxyURL *url.URL) (_ net.Conn, err error) { +func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr string, proxyURL *url.URL, grpcUA string) (_ net.Conn, err error) { defer func() { if err != nil { conn.Close() @@ -115,32 +109,28 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr stri return &bufConn{Conn: conn, r: r}, nil } -// newProxyDialer returns a dialer that connects to proxy first if necessary. -// The returned dialer checks if a proxy is necessary, dial to the proxy with the -// provided dialer, does HTTP CONNECT handshake and returns the connection. -func newProxyDialer(dialer func(context.Context, string) (net.Conn, error)) func(context.Context, string) (net.Conn, error) { - return func(ctx context.Context, addr string) (conn net.Conn, err error) { - var newAddr string - proxyURL, err := mapAddress(ctx, addr) - if err != nil { - if err != errDisabled { - return nil, err - } - newAddr = addr - } else { - newAddr = proxyURL.Host - } +// proxyDial dials, connecting to a proxy first if necessary. Checks if a proxy +// is necessary, dials, does the HTTP CONNECT handshake, and returns the +// connection. +func proxyDial(ctx context.Context, addr string, grpcUA string) (conn net.Conn, err error) { + newAddr := addr + proxyURL, err := mapAddress(ctx, addr) + if err != nil { + return nil, err + } + if proxyURL != nil { + newAddr = proxyURL.Host + } - conn, err = dialer(ctx, newAddr) - if err != nil { - return - } - if proxyURL != nil { - // proxy is disabled if proxyURL is nil. - conn, err = doHTTPConnectHandshake(ctx, conn, addr, proxyURL) - } + conn, err = (&net.Dialer{}).DialContext(ctx, "tcp", newAddr) + if err != nil { return } + if proxyURL != nil { + // proxy is disabled if proxyURL is nil. + conn, err = doHTTPConnectHandshake(ctx, conn, addr, proxyURL, grpcUA) + } + return } func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) error { diff --git a/proxy_test.go b/internal/transport/proxy_test.go similarity index 90% rename from proxy_test.go rename to internal/transport/proxy_test.go index c9604be6235..628b1fddc49 100644 --- a/proxy_test.go +++ b/internal/transport/proxy_test.go @@ -18,7 +18,7 @@ * */ -package grpc +package transport import ( "bufio" @@ -138,15 +138,9 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy defer overwrite(hpfe)() // Dial to proxy server. - dialer := newProxyDialer(func(ctx context.Context, addr string) (net.Conn, error) { - if deadline, ok := ctx.Deadline(); ok { - return net.DialTimeout("tcp", addr, time.Until(deadline)) - } - return net.Dial("tcp", addr) - }) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - c, err := dialer(ctx, blis.Addr().String()) + c, err := proxyDial(ctx, blis.Addr().String(), "test") if err != nil { t.Fatalf("http connect Dial failed: %v", err) } @@ -173,9 +167,6 @@ func (s) TestHTTPConnect(t *testing.T) { if req.Method != http.MethodConnect { return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect) } - if req.UserAgent() != grpcUA { - return fmt.Errorf("unexpect user agent %q, want %q", req.UserAgent(), grpcUA) - } return nil }, ) @@ -195,9 +186,6 @@ func (s) TestHTTPConnectBasicAuth(t *testing.T) { if req.Method != http.MethodConnect { return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect) } - if req.UserAgent() != grpcUA { - return fmt.Errorf("unexpect user agent %q, want %q", req.UserAgent(), grpcUA) - } wantProxyAuthStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password)) if got := req.Header.Get(proxyAuthHeaderKey); got != wantProxyAuthStr { gotDecoded, _ := base64.StdEncoding.DecodeString(got) diff --git a/internal/transport/transport.go b/internal/transport/transport.go index b74030a9687..9c8f79cb4b2 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -569,6 +569,8 @@ type ConnectOptions struct { ChannelzParentID int64 // MaxHeaderListSize sets the max (uncompressed) size of header list that is prepared to be received. MaxHeaderListSize *uint32 + // UseProxy specifies if a proxy should be used. + UseProxy bool } // NewClientTransport establishes the transport with the required ConnectOptions diff --git a/resolver_conn_wrapper_test.go b/resolver_conn_wrapper_test.go index e125976a535..f13a408937b 100644 --- a/resolver_conn_wrapper_test.go +++ b/resolver_conn_wrapper_test.go @@ -45,9 +45,6 @@ func (s) TestDialParseTargetUnknownScheme(t *testing.T) { }{ {"/unix/socket/address", "/unix/socket/address"}, - // Special test for "unix:///". - {"unix:///unix/socket/address", "unix:///unix/socket/address"}, - // For known scheme. {"passthrough://a.server.com/google.com", "google.com"}, } { diff --git a/rpc_util.go b/rpc_util.go index f0609f2a4e9..ea5bb8d0c2b 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -27,7 +27,6 @@ import ( "io" "io/ioutil" "math" - "net/url" "strings" "sync" "time" @@ -872,40 +871,6 @@ func setCallInfoCodec(c *callInfo) error { return nil } -// parseDialTarget returns the network and address to pass to dialer -func parseDialTarget(target string) (net string, addr string) { - net = "tcp" - - m1 := strings.Index(target, ":") - m2 := strings.Index(target, ":/") - - // handle unix:addr which will fail with url.Parse - if m1 >= 0 && m2 < 0 { - if n := target[0:m1]; n == "unix" { - net = n - addr = target[m1+1:] - return net, addr - } - } - if m2 >= 0 { - t, err := url.Parse(target) - if err != nil { - return net, target - } - scheme := t.Scheme - addr = t.Path - if scheme == "unix" { - net = scheme - if addr == "" { - addr = t.Host - } - return net, addr - } - } - - return net, target -} - // channelzData is used to store channelz related data for ClientConn, addrConn and Server. // These fields cannot be embedded in the original structs (e.g. ClientConn), since to do atomic // operation on int64 variable on 32-bit machine, user is responsible to enforce memory alignment. diff --git a/rpc_util_test.go b/rpc_util_test.go index 2449c23815e..90912d52a22 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -191,27 +191,6 @@ func (s) TestToRPCErr(t *testing.T) { } } -func (s) TestParseDialTarget(t *testing.T) { - for _, test := range []struct { - target, wantNet, wantAddr string - }{ - {"unix:etcd:0", "unix", "etcd:0"}, - {"unix:///tmp/unix-3", "unix", "/tmp/unix-3"}, - {"unix://domain", "unix", "domain"}, - {"unix://etcd:0", "unix", "etcd:0"}, - {"unix:///etcd:0", "unix", "/etcd:0"}, - {"passthrough://unix://domain", "tcp", "passthrough://unix://domain"}, - {"https://google.com:443", "tcp", "https://google.com:443"}, - {"dns:///google.com", "tcp", "dns:///google.com"}, - {"/unix/socket/address", "tcp", "/unix/socket/address"}, - } { - gotNet, gotAddr := parseDialTarget(test.target) - if gotNet != test.wantNet || gotAddr != test.wantAddr { - t.Errorf("parseDialTarget(%q) = %s, %s want %s, %s", test.target, gotNet, gotAddr, test.wantNet, test.wantAddr) - } - } -} - // bmEncode benchmarks encoding a Protocol Buffer message containing mSize // bytes. func bmEncode(b *testing.B, mSize int) { diff --git a/test/authority_test.go b/test/authority_test.go index 6cd5d82eec1..11c3c4c1af8 100644 --- a/test/authority_test.go +++ b/test/authority_test.go @@ -21,17 +21,19 @@ package test import ( "context" "fmt" + "net" "os" "testing" "time" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" testpb "google.golang.org/grpc/test/grpc_testing" ) -func runUnixTest(t *testing.T, address, target, expectedAuthority string) { +func runUnixTest(t *testing.T, address, target, expectedAuthority string, dialer func(context.Context, string) (net.Conn, error)) { if err := os.RemoveAll(address); err != nil { t.Fatalf("Error removing socket file %v: %v\n", address, err) } @@ -57,7 +59,11 @@ func runUnixTest(t *testing.T, address, target, expectedAuthority string) { address: address, target: target, } - if err := us.Start(nil); err != nil { + opts := []grpc.DialOption{} + if dialer != nil { + opts = append(opts, grpc.WithContextDialer(dialer)) + } + if err := us.Start(nil, opts...); err != nil { t.Fatalf("Error starting endpoint server: %v", err) return } @@ -70,6 +76,8 @@ func runUnixTest(t *testing.T, address, target, expectedAuthority string) { } } +// TestUnix does end to end tests with the various supported unix target +// formats, ensuring that the authority is set to localhost in every case. func (s) TestUnix(t *testing.T) { tests := []struct { name string @@ -78,19 +86,19 @@ func (s) TestUnix(t *testing.T) { authority string }{ { - name: "Unix1", + name: "UnixRelative", address: "sock.sock", target: "unix:sock.sock", authority: "localhost", }, { - name: "Unix2", + name: "UnixAbsolute", address: "/tmp/sock.sock", target: "unix:/tmp/sock.sock", authority: "localhost", }, { - name: "Unix3", + name: "UnixAbsoluteAlternate", address: "/tmp/sock.sock", target: "unix:///tmp/sock.sock", authority: "localhost", @@ -98,7 +106,44 @@ func (s) TestUnix(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - runUnixTest(t, test.address, test.target, test.authority) + runUnixTest(t, test.address, test.target, test.authority, nil) + }) + } +} + +// TestUnixCustomDialer does end to end tests with various supported unix target +// formats, ensuring that the target sent to the dialer does NOT have the +// "unix:" prefix stripped. +func (s) TestUnixCustomDialer(t *testing.T) { + tests := []struct { + name string + address string + target string + authority string + }{ + { + name: "UnixRelative", + address: "sock.sock", + target: "unix:sock.sock", + authority: "localhost", + }, + { + name: "UnixAbsolute", + address: "/tmp/sock.sock", + target: "unix:/tmp/sock.sock", + authority: "localhost", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dialer := func(ctx context.Context, address string) (net.Conn, error) { + if address != test.target { + return nil, fmt.Errorf("expected target %v in custom dialer, instead got %v", test.target, address) + } + address = address[len("unix:"):] + return (&net.Dialer{}).DialContext(ctx, "unix", address) + } + runUnixTest(t, test.address, test.target, test.authority, dialer) }) } }