Skip to content

Commit

Permalink
Merge pull request #497 from amir-khassaia/feat/http-connect-proxy-su…
Browse files Browse the repository at this point in the history
…pport

Add client option `SetConnectionAttemptHandler` that will be called prior to dialling a broker. This enables a connection specific tls.Config to be set (providing better support for proxies and tls with multiple brokers). In addition a new example has been added that demonstrates how to connect via a proxy (using SNI).
  • Loading branch information
MattBrittan committed Apr 22, 2021
2 parents a140ed8 + 4c25813 commit c15e250
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 6 deletions.
13 changes: 9 additions & 4 deletions client.go
Expand Up @@ -379,8 +379,13 @@ func (c *client) attemptConnection() (net.Conn, byte, bool, error) {
cm := newConnectMsgFromOptions(&c.options, broker)
DEBUG.Println(CLI, "about to write new connect msg")
CONN:
tlsCfg := c.options.TLSConfig
if c.options.OnConnectAttempt != nil {
DEBUG.Println(CLI, "using custom onConnectAttempt handler...")
tlsCfg = c.options.OnConnectAttempt(broker, c.options.TLSConfig)
}
// Start by opening the network connection (tcp, tls, ws) etc
conn, err = openConnection(broker, c.options.TLSConfig, c.options.ConnectTimeout, c.options.HTTPHeaders, c.options.WebsocketOptions)
conn, err = openConnection(broker, tlsCfg, c.options.ConnectTimeout, c.options.HTTPHeaders, c.options.WebsocketOptions)
if err != nil {
ERROR.Println(CLI, err.Error())
WARN.Println(CLI, "failed to connect to broker, trying next")
Expand All @@ -397,7 +402,7 @@ func (c *client) attemptConnection() (net.Conn, byte, bool, error) {

// We may be have to attempt the connection with MQTT 3.1
if conn != nil {
conn.Close()
_ = conn.Close()
}
if !c.options.protocolVersionExplicit && protocolVersion == 4 { // try falling back to 3.1?
DEBUG.Println(CLI, "Trying reconnect using MQTT 3.1 protocol")
Expand Down Expand Up @@ -504,8 +509,8 @@ func (c *client) internalConnLost(err error) {
}
}

// startCommsWorkers is called when the connection is up. It starts off all of the routines needed to process incoming and
// outgoing messages.
// startCommsWorkers is called when the connection is up.
// It starts off all of the routines needed to process incoming and outgoing messages.
// Returns true if the comms workers were started (i.e. they were not already running)
func (c *client) startCommsWorkers(conn net.Conn, inboundFromStore <-chan packets.ControlPacket) bool {
DEBUG.Println(CLI, "startCommsWorkers called")
Expand Down
79 changes: 79 additions & 0 deletions cmd/httpproxy/httpproxy.go
@@ -0,0 +1,79 @@
package main

import (
"bufio"
"fmt"
"net"
"net/http"
"net/url"

"golang.org/x/net/proxy"
)

// httpProxy is a HTTP/HTTPS connect capable proxy.
type httpProxy struct {
host string
haveAuth bool
username string
password string
forward proxy.Dialer
}

func (s httpProxy) String() string {
return fmt.Sprintf("HTTP proxy dialer for %s", s.host)
}

func newHTTPProxy(uri *url.URL, forward proxy.Dialer) (proxy.Dialer, error) {
s := new(httpProxy)
s.host = uri.Host
s.forward = forward
if uri.User != nil {
s.haveAuth = true
s.username = uri.User.Username()
s.password, _ = uri.User.Password()
}

return s, nil
}

func (s *httpProxy) Dial(_, addr string) (net.Conn, error) {
reqURL := url.URL{
Scheme: "https",
Host: addr,
}

req, err := http.NewRequest("CONNECT", reqURL.String(), nil)
if err != nil {
return nil, err
}
req.Close = false
if s.haveAuth {
req.SetBasicAuth(s.username, s.password)
}
req.Header.Set("User-Agent", "paho.mqtt")

// Dial and create the client connection.
c, err := s.forward.Dial("tcp", s.host)
if err != nil {
return nil, err
}

err = req.Write(c)
if err != nil {
_ = c.Close()
return nil, err
}

resp, err := http.ReadResponse(bufio.NewReader(c), req)
if err != nil {
_ = c.Close()
return nil, err
}
_ = resp.Body.Close()
if resp.StatusCode != http.StatusOK {
_ = c.Close()
return nil, fmt.Errorf("proxied connection returned an error: %v", resp.Status)
}

return c, nil
}
108 changes: 108 additions & 0 deletions cmd/httpproxy/main.go
@@ -0,0 +1,108 @@
/*
* Copyright (c) 2013 IBM Corp.
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* which accompanies this distribution, and is available at
* http://www.eclipse.org/legal/epl-v10.html
*
* Contributors:
* Seth Hoenig
* Allan Stockdill-Mander
* Mike Robertson
*/

package main

import (
"crypto/tls"
"flag"
"fmt"
"golang.org/x/net/proxy"
"log"
"net/url"

// "log"
"os"
"os/signal"
"strconv"
"syscall"
"time"

MQTT "github.com/eclipse/paho.mqtt.golang"
)

func onMessageReceived(_ MQTT.Client, message MQTT.Message) {
fmt.Printf("Received message on topic: %s\nMessage: %s\n", message.Topic(), message.Payload())
}

func init() {
// Pre-register custom HTTP proxy dialers for use with proxy.FromEnvironment
proxy.RegisterDialerType("http", newHTTPProxy)
proxy.RegisterDialerType("https", newHTTPProxy)
}

/**
* Illustrates how to make an MQTT connection with HTTP proxy CONNECT support.
* Specify proxy via environment variable: eg: ALL_PROXY=https://proxy_host:port
*/
func main() {
MQTT.DEBUG = log.New(os.Stdout, "", 0)
MQTT.ERROR = log.New(os.Stderr, "", 0)

c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)

hostname, _ := os.Hostname()

server := flag.String("server", "tcp://127.0.0.1:1883", "The full URL of the MQTT server to "+
"connect to ex: tcp://127.0.0.1:1883")
topic := flag.String("topic", "#", "Topic to subscribe to")
qos := flag.Int("qos", 0, "The QoS to subscribe to messages at")
clientid := flag.String("clientid", hostname+strconv.Itoa(time.Now().Second()), "A clientid for the connection")
username := flag.String("username", "", "A username to authenticate to the MQTT server")
password := flag.String("password", "", "Password to match username")
token := flag.String("token", "", "An optional token credential to authenticate with")
skipVerify := flag.Bool("skipVerify", false, "Controls whether TLS certificate is verified")
flag.Parse()

connOpts := MQTT.NewClientOptions().AddBroker(*server).
SetClientID(*clientid).
SetCleanSession(true).
SetProtocolVersion(4)

if *username != "" {
connOpts.SetUsername(*username)
if *password != "" {
connOpts.SetPassword(*password)
}
} else if *token != "" {
connOpts.SetCredentialsProvider(func() (string, string) {
return "unused", *token
})
}

connOpts.SetTLSConfig(&tls.Config{InsecureSkipVerify: *skipVerify, ClientAuth: tls.NoClientCert})

connOpts.OnConnect = func(c MQTT.Client) {
if token := c.Subscribe(*topic, byte(*qos), onMessageReceived); token.Wait() && token.Error() != nil {
panic(token.Error())
}
}

// Illustrates customized TLS configuration prior to connection attempt
connOpts.OnConnectAttempt = func(broker *url.URL, tlsCfg *tls.Config) *tls.Config {
cfg := tlsCfg.Clone()
cfg.ServerName = broker.Hostname()
return cfg
}

client := MQTT.NewClient(connOpts)
if token := client.Connect(); token.Wait() && token.Error() != nil {
panic(token.Error())
} else {
fmt.Printf("Connected to %s\n", *server)
}

<-c
}
5 changes: 3 additions & 2 deletions netconn.go
Expand Up @@ -30,7 +30,8 @@ import (
// This just establishes the network connection; once established the type of connection should be irrelevant
//

// openConnection opens a network connection using the protocol indicated in the URL. Does not carry out any MQTT specific handshakes
// openConnection opens a network connection using the protocol indicated in the URL.
// Does not carry out any MQTT specific handshakes.
func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, headers http.Header, websocketOptions *WebsocketOptions) (net.Conn, error) {
switch uri.Scheme {
case "ws":
Expand Down Expand Up @@ -81,7 +82,7 @@ func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, heade

err = tlsConn.Handshake()
if err != nil {
conn.Close()
_ = conn.Close()
return nil, err
}

Expand Down
14 changes: 14 additions & 0 deletions options.go
Expand Up @@ -49,6 +49,9 @@ type OnConnectHandler func(Client)
// the initial connection is lost
type ReconnectHandler func(Client, *ClientOptions)

// ConnectionAttemptHandler is invoked prior to making the initial connection.
type ConnectionAttemptHandler func(broker *url.URL, tlsCfg *tls.Config) *tls.Config

// ClientOptions contains configurable options for an Client. Note that these should be set using the
// relevant methods (e.g. AddBroker) rather than directly. See those functions for information on usage.
type ClientOptions struct {
Expand Down Expand Up @@ -79,6 +82,7 @@ type ClientOptions struct {
OnConnect OnConnectHandler
OnConnectionLost ConnectionLostHandler
OnReconnecting ReconnectHandler
OnConnectAttempt ConnectionAttemptHandler
WriteTimeout time.Duration
MessageChannelDepth uint
ResumeSubs bool
Expand Down Expand Up @@ -120,6 +124,7 @@ func NewClientOptions() *ClientOptions {
Store: nil,
OnConnect: nil,
OnConnectionLost: DefaultConnectionLostHandler,
OnConnectAttempt: nil,
WriteTimeout: 0, // 0 represents timeout disabled
ResumeSubs: false,
HTTPHeaders: make(map[string][]string),
Expand Down Expand Up @@ -321,6 +326,15 @@ func (o *ClientOptions) SetReconnectingHandler(cb ReconnectHandler) *ClientOptio
return o
}

// SetConnectionAttemptHandler sets the ConnectionAttemptHandler callback to be executed prior
// to each attempt to connect to an MQTT broker. Returns the *tls.Config that will be used when establishing
// the connection (a copy of the tls.Config from ClientOptions will be passed in along with the broker URL).
// This allows connection specific changes to be made to the *tls.Config.
func (o *ClientOptions) SetConnectionAttemptHandler(onConnectAttempt ConnectionAttemptHandler) *ClientOptions {
o.OnConnectAttempt = onConnectAttempt
return o
}

// SetWriteTimeout puts a limit on how long a mqtt publish should block until it unblocks with a
// timeout error. A duration of 0 never times out. Default never times out
func (o *ClientOptions) SetWriteTimeout(t time.Duration) *ClientOptions {
Expand Down

0 comments on commit c15e250

Please sign in to comment.