Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: resolve IPv4-IPv6 issues #2403

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 8 additions & 7 deletions docker.go
Expand Up @@ -156,11 +156,7 @@ func (c *DockerContainer) PortEndpoint(ctx context.Context, port nat.Port, proto
// Warning: this is based on your Docker host setting. Will fail if using an SSH tunnel
// You can use the "TC_HOST" env variable to set this yourself
func (c *DockerContainer) Host(ctx context.Context) (string, error) {
host, err := c.provider.DaemonHost(ctx)
if err != nil {
return "", err
}
return host, nil
return GetDockerHostIP(), nil
}

// Inspect gets the raw container info, caching the result for subsequent calls
Expand Down Expand Up @@ -189,7 +185,12 @@ func (c *DockerContainer) MappedPort(ctx context.Context, port nat.Port) (nat.Po

ports := inspect.NetworkSettings.Ports

for k, p := range ports {
boundPorts, err := core.BoundPortsFromBindings(ports)
if err != nil {
return "", err
}

for k, p := range boundPorts {
if k.Port() != port.Port() {
continue
}
Expand All @@ -199,7 +200,7 @@ func (c *DockerContainer) MappedPort(ctx context.Context, port nat.Port) (nat.Po
if len(p) == 0 {
continue
}
return nat.NewPort(k.Proto(), p[0].HostPort)
return nat.NewPort(k.Proto(), p.Port())
}

return "", errors.New("port not found")
Expand Down
7 changes: 5 additions & 2 deletions docker_client.go
Expand Up @@ -57,16 +57,19 @@ func (c *DockerClient) Info(ctx context.Context) (system.Info, error) {
API Version: %v
Operating System: %v
Total Memory: %v MB
Resolved Docker Host: %s
Resolved Docker Host: %s - %s
Resolved Docker Socket Path: %s
Test SessionID: %s
Test ProcessID: %s
`

dockerHost := core.ExtractDockerHost(ctx)

Logger.Printf(infoMessage, packagePath,
dockerInfo.ServerVersion, c.Client.ClientVersion(),
dockerInfo.OperatingSystem, dockerInfo.MemTotal/1024/1024,
core.ExtractDockerHost(ctx),
dockerHost,
core.GetDockerHostIPs(),
core.ExtractDockerSocket(ctx),
core.SessionID(),
core.ProcessID(),
Expand Down
52 changes: 52 additions & 0 deletions internal/core/bound_ports.go
@@ -0,0 +1,52 @@
package core

import (
"fmt"

"github.com/docker/go-connections/nat"
)

type BoundPorts map[nat.Port]nat.Port

// BoundPortsFromBindings returns a map of container ports to host ports.
// They are resolved from the port bindings in the inspect response,
// using the host IP addresses of the Docker host.
// This will resolve the issue of the host port being bound to multiple IP addresses
// in the IPv4 and IPv6 case.
func BoundPortsFromBindings(portMap nat.PortMap) (BoundPorts, error) {
hostIPs := GetDockerHostIPs()

boundPorts := make(BoundPorts)

for containerPort, bindings := range portMap {
if len(bindings) == 0 {
continue
}

hostPort, err := resolveHostPortBinding(hostIPs, bindings)
if err != nil {
return nil, fmt.Errorf("failed to resolve host port binding for port %s: %w", containerPort, err)
}

boundPorts[containerPort] = hostPort
}

return boundPorts, nil
}

// resolveHostPortBinding resolves the host port binding for the host IPs.
// It will return the host port for the first matching IP family (IPv4 or IPv6).
func resolveHostPortBinding(hostIPs []HostIP, portBindings []nat.PortBinding) (nat.Port, error) {
for _, hp := range hostIPs {
family := hp.Family

for _, portBinding := range portBindings {
hostIP := newHostIP(portBinding.HostIP)
if hostIP.Family == family {
return nat.Port(portBinding.HostPort), nil
}
}
}

return "", fmt.Errorf("no host port found for host IPs %v", hostIPs)
}
91 changes: 91 additions & 0 deletions internal/core/bound_ports_test.go
@@ -0,0 +1,91 @@
package core

import (
"fmt"
"testing"

"github.com/docker/go-connections/nat"
)

func TestResolveHostPortBinding(t *testing.T) {
type testCase struct {
name string
expectedPort nat.Port
hostIPs []HostIP
bindings []nat.PortBinding
expectedErr error
}

testCases := []testCase{
{
name: "should return IPv6-mapped host port when preferred",
hostIPs: []HostIP{
{Family: IPv6, Address: "::1"},
{Family: IPv4, Address: "127.0.0.1"},
},
bindings: []nat.PortBinding{
{HostIP: "0.0.0.0", HostPort: "50000"},
{HostIP: "::", HostPort: "50001"},
},
expectedPort: nat.Port("50001"),
},
{
name: "should return IPv4-mapped host port when preferred",
hostIPs: []HostIP{
{Family: IPv4, Address: "127.0.0.1"},
{Family: IPv6, Address: "::1"},
},
bindings: []nat.PortBinding{
{HostIP: "0.0.0.0", HostPort: "50000"},
{HostIP: "::", HostPort: "50001"},
},
expectedPort: nat.Port("50000"),
},
{
name: "should return mapped host port when dual stack IP",
hostIPs: []HostIP{
{Family: IPv4, Address: "127.0.0.1"},
{Family: IPv6, Address: "::1"},
},
bindings: []nat.PortBinding{
{HostIP: "", HostPort: "50000"},
},
expectedPort: nat.Port("50000"),
},
{
name: "should throw when no host port available for host IP family",
hostIPs: []HostIP{
{Family: IPv6, Address: "::1"},
},
bindings: []nat.PortBinding{
{HostIP: "0.0.0.0", HostPort: "50000"},
},
expectedPort: nat.Port(""), // that's the zero value returned by ResolveHostPortBinding
expectedErr: fmt.Errorf("no host port found for host IPs [%s (IPv6)]", "::1"),
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
resolvedPort, err := resolveHostPortBinding(tc.hostIPs, tc.bindings)

switch {
case err == nil && tc.expectedErr == nil:
break
case err == nil && tc.expectedErr != nil:
t.Errorf("did not receive expected error: %s", tc.expectedErr.Error())
return
case err != nil && tc.expectedErr == nil:
t.Errorf("unexpected error: %v", err)
return
case err.Error() != tc.expectedErr.Error():
t.Errorf("errors mismatch: %s != %s", err.Error(), tc.expectedErr.Error())
return
}

if resolvedPort != tc.expectedPort {
t.Errorf("resolved port mismatch: got %s, expected %s", resolvedPort, tc.expectedPort)
}
})
}
}
79 changes: 79 additions & 0 deletions internal/core/docker_host_ips.go
@@ -0,0 +1,79 @@
package core

import (
"context"
"fmt"
"net"
"sync"
)

type IPFamily string

const (
IPv4 IPFamily = "IPv4"
IPv6 IPFamily = "IPv6"
)

var (
hostIPs []HostIP
hostIPsOnce sync.Once
)

type HostIP struct {
Address string
Family IPFamily
}

func (h HostIP) String() string {
return fmt.Sprintf("%s (%s)", h.Address, h.Family)
}

func newHostIP(host string) HostIP {
var hip HostIP

ip := net.ParseIP(host)
if ip == nil {
host = "127.0.0.1"
ip = net.ParseIP(host)
}

hip.Address = host

if ip.To4() != nil {
hip.Family = IPv4
} else if ip.To16() != nil {
hip.Family = IPv6
}

return hip
}

// GetDockerHostIPs returns the IP addresses of the Docker host.
// The function is protected by a sync.Once to avoid unnecessary calculations.
func GetDockerHostIPs() []HostIP {
hostIPsOnce.Do(func() {
dockerHost := ExtractDockerHost(context.Background())
hostIPs = getDockerHostIPs(dockerHost)
})

return hostIPs
}

// getDockerHostIPs returns the IP addresses of the Docker host.
// The function is helpful for testing purposes,
// as it's not protected by the sync.Once.
func getDockerHostIPs(host string) []HostIP {
hip := newHostIP(host)

ips, err := net.LookupIP(hip.Address)
if err != nil {
return []HostIP{hip}
}

var hips = []HostIP{}
for _, ip := range ips {
hips = append(hostIPs, newHostIP(ip.String()))
}

return hips
}
54 changes: 54 additions & 0 deletions internal/core/docker_host_ips_test.go
@@ -0,0 +1,54 @@
package core

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestGetDockerHostIPs(t *testing.T) {
type args struct {
host string
}
tests := []struct {
name string
args args
hostIps []HostIP
}{
{
name: "should return a list of resolved host IPs when host is not an IP",
args: args{
host: "localhost",
},
hostIps: []HostIP{{Address: "127.0.0.1", Family: IPv4}},
},
{
name: "should return host IP and v4 family when host is an IPv4 IP",
args: args{
host: "127.0.0.1",
},
hostIps: []HostIP{{Address: "127.0.0.1", Family: IPv4}},
},
{
name: "should return host IP and v4 family when host is an IPv4 IP with tcp schema",
args: args{
host: "tcp://127.0.0.1:64692",
},
hostIps: []HostIP{{Address: "127.0.0.1", Family: IPv4}},
},
{
name: "should return host IP and v6 family when host is an IPv6 IP",
args: args{
host: "::1",
},
hostIps: []HostIP{{Address: "::1", Family: IPv6}},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hips := getDockerHostIPs(tt.args.host)
assert.Equal(t, tt.hostIps, hips)
})
}
}
8 changes: 4 additions & 4 deletions modules/cockroachdb/cockroachdb_test.go
Expand Up @@ -17,13 +17,13 @@ import (

func TestCockroach_Insecure(t *testing.T) {
suite.Run(t, &AuthNSuite{
url: "postgres://root@localhost:xxxxx/defaultdb?sslmode=disable",
url: "postgres://root@" + testcontainers.GetDockerHostIP() + ":xxxxx/defaultdb?sslmode=disable",
})
}

func TestCockroach_NotRoot(t *testing.T) {
suite.Run(t, &AuthNSuite{
url: "postgres://test@localhost:xxxxx/defaultdb?sslmode=disable",
url: "postgres://test@" + testcontainers.GetDockerHostIP() + ":xxxxx/defaultdb?sslmode=disable",
opts: []testcontainers.ContainerCustomizer{
cockroachdb.WithUser("test"),
},
Expand All @@ -32,7 +32,7 @@ func TestCockroach_NotRoot(t *testing.T) {

func TestCockroach_Password(t *testing.T) {
suite.Run(t, &AuthNSuite{
url: "postgres://foo:bar@localhost:xxxxx/defaultdb?sslmode=disable",
url: "postgres://foo:bar@" + testcontainers.GetDockerHostIP() + ":xxxxx/defaultdb?sslmode=disable",
opts: []testcontainers.ContainerCustomizer{
cockroachdb.WithUser("foo"),
cockroachdb.WithPassword("bar"),
Expand All @@ -45,7 +45,7 @@ func TestCockroach_TLS(t *testing.T) {
require.NoError(t, err)

suite.Run(t, &AuthNSuite{
url: "postgres://root@localhost:xxxxx/defaultdb?sslmode=verify-full",
url: "postgres://root@" + testcontainers.GetDockerHostIP() + ":xxxxx/defaultdb?sslmode=verify-full",
opts: []testcontainers.ContainerCustomizer{
cockroachdb.WithTLS(tlsCfg),
},
Expand Down
2 changes: 1 addition & 1 deletion modules/cockroachdb/examples_test.go
Expand Up @@ -45,5 +45,5 @@ func ExampleRunContainer() {

// Output:
// true
// postgres://root@localhost:xxx/defaultdb?sslmode=disable
// postgres://root@127.0.0.1:xxx/defaultdb?sslmode=disable
}
2 changes: 1 addition & 1 deletion modules/postgres/postgres_test.go
Expand Up @@ -96,7 +96,7 @@ func TestPostgres(t *testing.T) {
// Ensure connection string is using generic format
id, err := container.MappedPort(ctx, "5432/tcp")
require.NoError(t, err)
assert.Equal(t, fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable&application_name=test", user, password, "localhost", id.Port(), dbname), connStr)
assert.Equal(t, fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable&application_name=test", user, password, testcontainers.GetDockerHostIP(), id.Port(), dbname), connStr)

// perform assertions
db, err := sql.Open("postgres", connStr)
Expand Down