Skip to content

Commit

Permalink
Add build TLS config from URI (#98)
Browse files Browse the repository at this point in the history
* Add build TLS config from URI

Add cacertfile, certfile, keyfile, server_name_indication from https://www.rabbitmq.com/uri-query-parameters.html

* use absolute path in URI tests

* add special characters in URI TLS config test
  • Loading branch information
reddec committed Jul 13, 2022
1 parent da48d91 commit ec9c17a
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 7 deletions.
48 changes: 47 additions & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@ package amqp091
import (
"bufio"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"reflect"
"strconv"
Expand Down Expand Up @@ -211,7 +214,11 @@ func DialConfig(url string, config Config) (*Connection, error) {

if uri.Scheme == "amqps" {
if config.TLSClientConfig == nil {
config.TLSClientConfig = new(tls.Config)
tlsConfig, err := tlsConfigFromURI(uri)
if err != nil {
return nil, fmt.Errorf("create TLS config from URI: %w", err)
}
config.TLSClientConfig = tlsConfig
}

// If ServerName has not been specified in TLSClientConfig,
Expand Down Expand Up @@ -878,6 +885,45 @@ func (c *Connection) openComplete() error {
return nil
}

// tlsConfigFromURI tries to create TLS configuration based on query parameters.
// Returns default (empty) config in case no suitable client cert and/or client key not provided.
// Returns error in case certificates can not be parsed.
func tlsConfigFromURI(uri URI) (*tls.Config, error) {
var certPool *x509.CertPool
if uri.CACertFile != "" {
data, err := ioutil.ReadFile(uri.CACertFile)
if err != nil {
return nil, fmt.Errorf("read CA certificate: %w", err)
}

certPool = x509.NewCertPool()
certPool.AppendCertsFromPEM(data)
} else if sysPool, err := x509.SystemCertPool(); err != nil {
return nil, fmt.Errorf("load system certificates: %w", err)
} else {
certPool = sysPool
}

if uri.CertFile == "" || uri.KeyFile == "" {
// no client auth (mTLS), just server auth
return &tls.Config{
RootCAs: certPool,
ServerName: uri.ServerName,
}, nil
}

certificate, err := tls.LoadX509KeyPair(uri.CertFile, uri.KeyFile)
if err != nil {
return nil, fmt.Errorf("load client certificate: %w", err)
}

return &tls.Config{
Certificates: []tls.Certificate{certificate},
RootCAs: certPool,
ServerName: uri.ServerName,
}, nil
}

func max(a, b int) int {
if a > b {
return a
Expand Down
34 changes: 28 additions & 6 deletions uri.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,16 @@ var defaultURI = URI{

// URI represents a parsed AMQP URI string.
type URI struct {
Scheme string
Host string
Port int
Username string
Password string
Vhost string
Scheme string
Host string
Port int
Username string
Password string
Vhost string
CertFile string // client TLS auth - path to certificate (PEM)
CACertFile string // client TLS auth - path to CA certificate (PEM)
KeyFile string // client TLS auth - path to private key (PEM)
ServerName string // client TLS auth - server name
}

// ParseURI attempts to parse the given AMQP URI according to the spec.
Expand All @@ -52,6 +56,17 @@ type URI struct {
// Password: guest
// Vhost: /
//
// Supports TLS query parameters. See https://www.rabbitmq.com/uri-query-parameters.html
//
// certfile: <path/to/client_cert.pem>
// keyfile: <path/to/client_key.pem>
// cacertfile: <path/to/ca.pem>
// server_name_indication: <server name>
//
// If cacertfile is not provided, system CA certificates will be used.
// Mutual TLS (client auth) will be enabled only in case keyfile AND certfile provided.
//
// If Config.TLSClientConfig is set, TLS parameters from URI will be ignored.
func ParseURI(uri string) (URI, error) {
builder := defaultURI

Expand Down Expand Up @@ -113,6 +128,13 @@ func ParseURI(uri string) (URI, error) {
}
}

// see https://www.rabbitmq.com/uri-query-parameters.html
params := u.Query()
builder.CertFile = params.Get("certfile")
builder.KeyFile = params.Get("keyfile")
builder.CACertFile = params.Get("cacertfile")
builder.ServerName = params.Get("server_name_indication")

return builder, nil
}

Expand Down
20 changes: 20 additions & 0 deletions uri_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,3 +368,23 @@ func TestURIDefaultPortAmqps(t *testing.T) {
t.Fatal("Default port not correct for amqps, got:", uri.Port)
}
}

func TestURITLSConfig(t *testing.T) {
url := "amqps://foo.bar/?certfile=/foo/%D0%BF%D1%80%D0%B8%D0%B2%D0%B5%D1%82/cert.pem&keyfile=/foo/%E4%BD%A0%E5%A5%BD/key.pem&cacertfile=C:%5Ccerts%5Cca.pem&server_name_indication=example.com"
uri, err := ParseURI(url)
if err != nil {
t.Fatal("Could not parse")
}
if uri.CertFile != "/foo/привет/cert.pem" {
t.Fatal("Certfile not set")
}
if uri.CACertFile != "C:\\certs\\ca.pem" {
t.Fatal("CA not set")
}
if uri.KeyFile != "/foo/你好/key.pem" {
t.Fatal("Key not set")
}
if uri.ServerName != "example.com" {
t.Fatal("Server name not set")
}
}

0 comments on commit ec9c17a

Please sign in to comment.