Skip to content

Commit

Permalink
Add connection logging to help with debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
wrosenuance committed Dec 16, 2020
1 parent 095ece7 commit 86bf30a
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 5 deletions.
11 changes: 9 additions & 2 deletions conn_str.go
@@ -1,6 +1,7 @@
package mssql

import (
"errors"
"fmt"
"net"
"net/url"
Expand Down Expand Up @@ -38,6 +39,7 @@ type connectParams struct {
failOverPort uint64
packetSize uint16
fedAuthAccessToken string
tlsKeyLogFile string
}

// default packet size for TDS buffer
Expand Down Expand Up @@ -232,6 +234,11 @@ func parseConnectParams(dsn string) (connectParams, error) {
}
}

p.tlsKeyLogFile, ok = params["tls key log file"]
if ok && p.tlsKeyLogFile != "" && p.disableEncryption {
return p, errors.New("Cannot set tlsKeyLogFile when encryption is disabled")
}

return p, nil
}

Expand All @@ -247,8 +254,8 @@ func (p connectParams) toUrl() *url.URL {
}
res := url.URL{
Scheme: "sqlserver",
Host: p.host,
User: url.UserPassword(p.user, p.password),
Host: p.host,
User: url.UserPassword(p.user, p.password),
}
if p.instance != "" {
res.Path = p.instance
Expand Down
5 changes: 3 additions & 2 deletions conn_str_test.go
Expand Up @@ -67,6 +67,7 @@ func TestValidConnectionString(t *testing.T) {
{"trustservercertificate=false", func(p connectParams) bool { return !p.trustServerCertificate }},
{"certificate=abc", func(p connectParams) bool { return p.certificate == "abc" }},
{"hostnameincertificate=abc", func(p connectParams) bool { return p.hostInCertificate == "abc" }},
{"tls key log file=tls.log", func(p connectParams) bool { return p.tlsKeyLogFile == "tls.log" }},
{"connection timeout=3;dial timeout=4;keepalive=5", func(p connectParams) bool {
return p.conn_timeout == 3*time.Second && p.dial_timeout == 4*time.Second && p.keepAlive == 5*time.Second
}},
Expand Down Expand Up @@ -186,10 +187,10 @@ func testConnParams(t testing.TB) connectParams {
}
if len(os.Getenv("HOST")) > 0 && len(os.Getenv("DATABASE")) > 0 {
return connectParams{
host: os.Getenv("HOST"),
host: os.Getenv("HOST"),
instance: os.Getenv("INSTANCE"),
database: os.Getenv("DATABASE"),
user: os.Getenv("SQLUSER"),
user: os.Getenv("SQLUSER"),
password: os.Getenv("SQLPASSWORD"),
logFlags: logFlags,
}
Expand Down
80 changes: 80 additions & 0 deletions log_conn.go
@@ -0,0 +1,80 @@
package mssql

import (
"encoding/hex"
"net"
"strings"
"time"
)

type connLogger struct {
conn net.Conn
readKind, writeKind string
readCount, writeCount int
logger Logger
}

var _ net.Conn = &connLogger{}

func newConnLogger(conn net.Conn, kind string, logger Logger) net.Conn {
if len(kind) > 0 && !strings.HasPrefix(kind, " ") {
kind = " " + kind
}

cl := &connLogger{
conn: conn,
readKind: "R" + kind,
writeKind: "W" + kind,
logger: logger,
}

return cl
}

func (cl *connLogger) Read(p []byte) (n int, err error) {
n, err = cl.conn.Read(p)

if n > 0 {
dump := hex.Dump(p)
cl.logger.Printf("%s %d\n%s", cl.readKind, cl.readCount, dump)
cl.readCount += n
}

return
}

func (cl *connLogger) Write(p []byte) (n int, err error) {
n, err = cl.conn.Write(p)

if n > 0 {
dump := hex.Dump(p)
cl.logger.Printf("%s %d\n%s", cl.writeKind, cl.writeCount, dump)
cl.writeCount += n
}

return
}

func (cl *connLogger) Close() (err error) {
return cl.conn.Close()
}

func (cl *connLogger) LocalAddr() net.Addr {
return cl.conn.LocalAddr()
}

func (cl *connLogger) RemoteAddr() net.Addr {
return cl.conn.RemoteAddr()
}

func (cl *connLogger) SetDeadline(t time.Time) error {
return cl.conn.SetDeadline(t)
}

func (cl *connLogger) SetReadDeadline(t time.Time) error {
return cl.conn.SetReadDeadline(t)
}

func (cl *connLogger) SetWriteDeadline(t time.Time) error {
return cl.conn.SetWriteDeadline(t)
}
121 changes: 121 additions & 0 deletions log_conn_test.go
@@ -0,0 +1,121 @@
package mssql

import (
"net"
"sync/atomic"
"testing"
"time"
)

func TestConnLoggerOperations(t *testing.T) {
clt := &connLoggerTest{}
cl := newConnLogger(clt, "test", nullLogger{})
packet := append(make([]byte, 0, 10), 1, 2, 3, 4, 5)
n, err := cl.Read(packet)
if n != 10 || err != nil {
t.Error("Unexpected return value from call to Read()")
}

n, err = cl.Write(packet)
if n != 5 || err != nil {
t.Error("Unexpected return value from call to Write()")
}

if cl.Close() != nil {
t.Error("Unexpected return value from call to Close()")
}

if cl.LocalAddr() == nil {
t.Error("Unexpected return value from call to LocalAddr()")
}

if cl.RemoteAddr() == nil {
t.Error("Unexpected return value from call to RemoteAddr()")
}

if cl.SetDeadline(time.Now()) != nil {
t.Error("Unexpected return value from call to SetDeadline()")
}

if cl.SetReadDeadline(time.Now()) != nil {
t.Error("Unexpected return value from call to SetReadDeadline()")
}

if cl.SetWriteDeadline(time.Now()) != nil {
t.Error("Unexpected return value from call to SetWriteDeadline()")
}

if atomic.LoadInt32(&clt.calls) != 8 {
t.Error("Unexpected number of calls recorded")
}
}

type connLoggerTest struct {
calls int32
}

var _ net.Conn = &connLoggerTest{}

type addressTest struct {
}

var _ net.Addr = &addressTest{}

type nullLogger struct {
}

var _ Logger = nullLogger{}

func (n nullLogger) Printf(format string, v ...interface{}) {
}

func (n nullLogger) Println(v ...interface{}) {
}

func (a *addressTest) Network() string {
return "test"
}

func (a *addressTest) String() string {
return "test"
}

func (cl *connLoggerTest) Read(p []byte) (int, error) {
atomic.AddInt32(&cl.calls, 1)
return cap(p), nil
}

func (cl *connLoggerTest) Write(p []byte) (int, error) {
atomic.AddInt32(&cl.calls, 1)
return len(p), nil
}

func (cl *connLoggerTest) Close() error {
atomic.AddInt32(&cl.calls, 1)
return nil
}

func (cl *connLoggerTest) LocalAddr() net.Addr {
atomic.AddInt32(&cl.calls, 1)
return &addressTest{}
}

func (cl *connLoggerTest) RemoteAddr() net.Addr {
atomic.AddInt32(&cl.calls, 1)
return &addressTest{}
}

func (cl *connLoggerTest) SetDeadline(t time.Time) error {
atomic.AddInt32(&cl.calls, 1)
return nil
}

func (cl *connLoggerTest) SetReadDeadline(t time.Time) error {
atomic.AddInt32(&cl.calls, 1)
return nil
}

func (cl *connLoggerTest) SetWriteDeadline(t time.Time) error {
atomic.AddInt32(&cl.calls, 1)
return nil
}
20 changes: 19 additions & 1 deletion tds.go
Expand Up @@ -10,6 +10,7 @@ import (
"io"
"io/ioutil"
"net"
"os"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -140,6 +141,7 @@ const (
logParams = 16
logTransaction = 32
logDebug = 64
logTraffic = 128
)

type columnStruct struct {
Expand Down Expand Up @@ -876,6 +878,10 @@ initiate_connection:
return nil, err
}

if p.logFlags&logTraffic != 0 {
conn = newConnLogger(conn, "TCP", log)
}

toconn := newTimeoutConn(conn, p.conn_timeout)

outbuf := newTdsBuffer(p.packetSize, toconn)
Expand Down Expand Up @@ -936,6 +942,14 @@ initiate_connection:
if p.trustServerCertificate {
config.InsecureSkipVerify = true
}
if p.tlsKeyLogFile != "" {
if w, err := os.OpenFile(p.tlsKeyLogFile, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0600); err == nil {
defer w.Close()
config.KeyLogWriter = w
} else {
return nil, fmt.Errorf("Cannot open TLS key log file %s: %v", p.tlsKeyLogFile, err)
}
}
config.ServerName = p.hostInCertificate
// fix for https://github.com/denisenkom/go-mssqldb/issues/166
// Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
Expand All @@ -948,7 +962,11 @@ initiate_connection:
tlsConn := tls.Client(&passthrough, &config)
err = tlsConn.Handshake()
passthrough.c = toconn
outbuf.transport = tlsConn
if sess.logFlags&logTraffic != 0 {
outbuf.transport = newConnLogger(tlsConn, "TLS", log)
} else {
outbuf.transport = tlsConn
}
if err != nil {
return nil, fmt.Errorf("TLS Handshake failed: %v", err)
}
Expand Down

0 comments on commit 86bf30a

Please sign in to comment.