Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add the tls min version to dsn #736

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 22 additions & 3 deletions msdsn/conn_str.go
Expand Up @@ -75,11 +75,11 @@ type Config struct {
PacketSize uint16
}

func SetupTLS(certificate string, insecureSkipVerify bool, hostInCertificate string) (*tls.Config, error) {
func SetupTLS(certificate string, insecureSkipVerify bool, hostInCertificate string, tlsMinVer uint16) (*tls.Config, error) {
config := tls.Config{
ServerName: hostInCertificate,
InsecureSkipVerify: insecureSkipVerify,

MinVersion: tlsMinVer,
// fix for https://github.com/denisenkom/go-mssqldb/issues/166
// Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
// while SQL Server seems to expect one TCP segment per encrypted TDS package.
Expand Down Expand Up @@ -254,10 +254,29 @@ func Parse(dsn string) (Config, map[string]string, error) {
hostInCertificate = p.Host
p.HostInCertificateProvided = false
}
tlsversion, ok := params["tlsminversion"]
tlsMinVer := uint16(0)
if ok {
tlsversion = strings.ToUpper(tlsversion)
switch tlsversion {
case "TLS1.0":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not just "1.0" and "1.2" etc since we already know the parameter name is about TLS?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zhiyunliu we'd consider taking this change in the microsoft fork
github.com/microsoft/go-mssqldb

tlsMinVer = tls.VersionTLS10
case "TLS1.1":
tlsMinVer = tls.VersionTLS11
case "TLS1.2":
tlsMinVer = tls.VersionTLS12
/*comment by go1.8 ~ go1.11 has no tls.VersionTLS13
case "TLS1.3":
tlsMinVer = tls.VersionTLS13
*/
default:
tlsMinVer = 0
}
}

if p.Encryption != EncryptionDisabled {
var err error
p.TLSConfig, err = SetupTLS(certificate, trustServerCert, hostInCertificate)
p.TLSConfig, err = SetupTLS(certificate, trustServerCert, hostInCertificate, tlsMinVer)
if err != nil {
return p, params, fmt.Errorf("failed to setup TLS: %w", err)
}
Expand Down
29 changes: 29 additions & 0 deletions msdsn/conn_str_test.go
@@ -1,6 +1,7 @@
package msdsn

import (
"crypto/tls"
"reflect"
"testing"
"time"
Expand Down Expand Up @@ -196,3 +197,31 @@ func TestConnParseRoundTripFixed(t *testing.T) {
t.Fatal("Parameters do not match after roundtrip", params, rtParams)
}
}

func TestConnParseWithTlsVersion(t *testing.T) {
tests := []struct {
name string
connStr string
wantCfg *Config
}{
{name: "1.TLS1.0", connStr: "sqlserver://someuser@somehost?tlsminversion=tls1.0", wantCfg: &Config{TLSConfig: &tls.Config{MinVersion: tls.VersionTLS10}}},
{name: "2.TLS1.1", connStr: "sqlserver://someuser@somehost?tlsminversion=tls1.1", wantCfg: &Config{TLSConfig: &tls.Config{MinVersion: tls.VersionTLS11}}},
{name: "3.TLS1.2", connStr: "sqlserver://someuser@somehost?tlsminversion=tls1.2", wantCfg: &Config{TLSConfig: &tls.Config{MinVersion: tls.VersionTLS12}}},
{name: "4.no tlsminversion parameter", connStr: "sqlserver://someuser@somehost", wantCfg: &Config{TLSConfig: &tls.Config{MinVersion: 0}}},
{name: "5.wrong tlsminversion parameter", connStr: "sqlserver://someuser@somehost?tlsminversion=wrongtlsversion", wantCfg: &Config{TLSConfig: &tls.Config{MinVersion: 0}}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, _, err := Parse(tt.connStr)
if err != nil {
t.Errorf("%s Parse Error:%+v", tt.name, err)
return
}
if got.TLSConfig.MinVersion != tt.wantCfg.TLSConfig.MinVersion {
t.Errorf("%s Parse MinVersion not match. want:%d, got:%d", tt.name, tt.wantCfg.TLSConfig.MinVersion, got.TLSConfig.MinVersion)
return
}
})
}

}
140 changes: 76 additions & 64 deletions tds.go
Expand Up @@ -1050,50 +1050,7 @@ func connect(ctx context.Context, c *Connector, logger ContextLogger, p msdsn.Co
dialCtx, cancel = context.WithTimeout(ctx, dt)
defer cancel()
}
// if instance is specified use instance resolution service
if len(p.Instance) > 0 && p.Port != 0 && uint64(p.LogFlags)&logDebug != 0 {
// both instance name and port specified
// when port is specified instance name is not used
// you should not provide instance name when you provide port
logger.Log(ctx, msdsn.LogDebug, "WARN: You specified both instance name and port in the connection string, port will be used and instance name will be ignored")
}
if len(p.Instance) > 0 {
p.Instance = strings.ToUpper(p.Instance)
d := c.getDialer(&p)
instances, err := getInstances(dialCtx, d, p.Host)
if err != nil {
f := "unable to get instances from Sql Server Browser on host %v: %v"
return nil, fmt.Errorf(f, p.Host, err.Error())
}
strport, ok := instances[p.Instance]["tcp"]
if !ok {
f := "no instance matching '%v' returned from host '%v'"
return nil, fmt.Errorf(f, p.Instance, p.Host)
}
port, err := strconv.ParseUint(strport, 0, 16)
if err != nil {
f := "invalid tcp port returned from Sql Server Browser '%v': %v"
return nil, fmt.Errorf(f, strport, err.Error())
}
p.Port = port
}
if p.Port == 0 {
p.Port = defaultServerPort
}

packetSize := p.PacketSize
if packetSize == 0 {
packetSize = defaultPacketSize
}
// Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes
// NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request
// a higher packet size, the server will respond with an ENVCHANGE request to
// alter the packet size to 16383 bytes.
if packetSize < 512 {
packetSize = 512
} else if packetSize > 32767 {
packetSize = 32767
}
err = prepareMSDSN(ctx, c, logger, &p)

initiate_connection:
conn, err := dialConnection(dialCtx, c, p)
Expand All @@ -1103,7 +1060,7 @@ initiate_connection:

toconn := newTimeoutConn(conn, p.ConnTimeout)

outbuf := newTdsBuffer(packetSize, toconn)
outbuf := newTdsBuffer(p.PacketSize, toconn)
sess := tdsSession{
buf: outbuf,
logger: logger,
Expand Down Expand Up @@ -1136,25 +1093,8 @@ initiate_connection:
}

if encrypt != encryptNotSup {
var config *tls.Config
if pc := p.TLSConfig; pc != nil {
config = pc
if config.DynamicRecordSizingDisabled == false {
config = config.Clone()

// fix for https://github.com/denisenkom/go-mssqldb/issues/166
// Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
// while SQL Server seems to expect one TCP segment per encrypted TDS package.
// Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
config.DynamicRecordSizingDisabled = true
}
}
if config == nil {
config, err = msdsn.SetupTLS("", false, p.Host)
if err != nil {
return nil, err
}
}
//refactor tls config build.
config := prepareTLSConfig(p)

// setting up connection handler which will allow wrapping of TLS handshake packets inside TDS stream
handshakeConn := tlsHandshakeConn{buf: outbuf}
Expand Down Expand Up @@ -1288,3 +1228,75 @@ func resolveServerPort(port uint64) uint64 {

return port
}

func prepareTLSConfig(p msdsn.Config) (config *tls.Config) {
if pc := p.TLSConfig; pc != nil {
config = pc
if config.DynamicRecordSizingDisabled == false {
config = config.Clone()

// fix for https://github.com/denisenkom/go-mssqldb/issues/166
// Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
// while SQL Server seems to expect one TCP segment per encrypted TDS package.
// Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
config.DynamicRecordSizingDisabled = true
}
}
if config == nil {
//In this scenario, error will not appear
config, _ = msdsn.SetupTLS("", false, p.Host, 0)
}
return
}

func prepareMSDSN(dialCtx context.Context, c *Connector, logger ContextLogger, p *msdsn.Config) (err error) {

// if instance is specified use instance resolution service
if len(p.Instance) > 0 && p.Port != 0 && uint64(p.LogFlags)&logDebug != 0 {
// both instance name and port specified
// when port is specified instance name is not used
// you should not provide instance name when you provide port
logger.Log(dialCtx, msdsn.LogDebug, "WARN: You specified both instance name and port in the connection string, port will be used and instance name will be ignored")
}
if len(p.Instance) > 0 {
p.Instance = strings.ToUpper(p.Instance)
d := c.getDialer(p)
instances, err := getInstances(dialCtx, d, p.Host)
if err != nil {
const f = "unable to get instances from Sql Server Browser on host %v: %v"
return fmt.Errorf(f, p.Host, err.Error())
}
strport, ok := instances[p.Instance]["tcp"]
if !ok {
const f = "no instance matching '%v' returned from host '%v'"
return fmt.Errorf(f, p.Instance, p.Host)
}
port, err := strconv.ParseUint(strport, 0, 16)
if err != nil {
const f = "invalid tcp port returned from Sql Server Browser '%v': %v"
return fmt.Errorf(f, strport, err.Error())
}
p.Port = port
}
if p.Port == 0 {
p.Port = defaultServerPort
}

packetSize := p.PacketSize
if packetSize == 0 {
packetSize = defaultPacketSize
}
// Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes
// NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request
// a higher packet size, the server will respond with an ENVCHANGE request to
// alter the packet size to 16383 bytes.
if packetSize < 512 {
packetSize = 512
} else if packetSize > 32767 {
packetSize = 32767
}

p.PacketSize = packetSize
return err

}
127 changes: 127 additions & 0 deletions tds_test.go
Expand Up @@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"context"
"crypto/tls"
"database/sql"
"encoding/hex"
"fmt"
Expand All @@ -12,6 +13,7 @@ import (
"net/url"
"os"
"path"
"reflect"
"runtime"
"sync"
"testing"
Expand Down Expand Up @@ -758,3 +760,128 @@ func runBatch(t testing.TB, p msdsn.Config) {
}
}
}

func Test_prepareTLSConfig(t *testing.T) {

tests := []struct {
name string
p msdsn.Config
wantConfig *tls.Config
wantErr bool
}{
{name: "1.TLSConfig is null ", p: msdsn.Config{Host: "testserver"}, wantConfig: &tls.Config{ServerName: "testserver"}, wantErr: false},
{name: "2.TLSConfig not null ,DynamicRecordSizingDisabled=false", p: msdsn.Config{TLSConfig: &tls.Config{DynamicRecordSizingDisabled: false, ServerName: "testserver", MinVersion: tls.VersionTLS10}}, wantConfig: &tls.Config{ServerName: "testserver", DynamicRecordSizingDisabled: true, MinVersion: tls.VersionTLS10}, wantErr: false},
{name: "3.TLSConfig not null ,DynamicRecordSizingDisabled=true", p: msdsn.Config{TLSConfig: &tls.Config{DynamicRecordSizingDisabled: true, ServerName: "testserver", MinVersion: tls.VersionTLS10}}, wantConfig: &tls.Config{ServerName: "testserver", DynamicRecordSizingDisabled: true, MinVersion: tls.VersionTLS10}, wantErr: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotConfig := prepareTLSConfig(tt.p)

if gotConfig.ServerName != tt.wantConfig.ServerName ||
gotConfig.MinVersion != tt.wantConfig.MinVersion {
t.Errorf("prepareTLSConfig() = %v, want %v", gotConfig, tt.wantConfig)
}
})
}
}

func Test_prepareMSDSN(t *testing.T) {

tl := testLogger{t: t}
defer tl.StopLogging()
SetLogger(&tl)
var newdialerCallErr = func() Dialer {
return NewMockTransportDialer(
[]string{
" 03",
},
[]string{
" 05 00 00 49 6e 73 74 61 6e 63 65 4e 61 6d 65 3b 62 3b 3b",
},
)
}
var newdialerCallErr2 = func() Dialer {
return NewMockTransportDialer(
[]string{
" 03",
},
[]string{
" 05 00 00 49 6e 73 74 61 6e 63 65 4e 61 6d 65 3b 62 3b 74 63 70 3b 61 62 63 3b 3b",
},
)
}
var newdialerCallErr3 = func() Dialer {
return NewMockTransportDialer(
[]string{
" 04",
},
[]string{
" 05 00 00 49 6e 73 74 61 6e 63 65 4e 61 6d 65 3b 62 3b 74 63 70 3b 61 62 63 3b 3b",
},
)
}
var newdialerCallSuc = func() Dialer {
return NewMockTransportDialer(
[]string{
" 03",
},
[]string{
" 05 00 00 49 6e 73 74 61 6e 63 65 4e 61 6d 65 3b 62 3b 74 63 70 3b 31 34 33 33 3b 3b",
},
)
}

type args struct {
ctx context.Context
p *msdsn.Config
}

tests := []struct {
name string
args args
dialCall func() Dialer
wantDialCtx context.Context
wantErr bool
}{
{name: "1.", dialCall: newdialerCallErr, args: args{ctx: context.Background(), p: &msdsn.Config{Instance: "test", Port: 1433, LogFlags: msdsn.LogDebug}}, wantDialCtx: nil, wantErr: true},
{name: "2.", dialCall: newdialerCallErr, args: args{ctx: context.Background(), p: &msdsn.Config{Instance: "test", Port: 0, LogFlags: msdsn.LogErrors}}, wantDialCtx: nil, wantErr: true},
{name: "3.", dialCall: newdialerCallErr, args: args{ctx: context.Background(), p: &msdsn.Config{Instance: "", PacketSize: 1, LogFlags: msdsn.LogErrors}}, wantDialCtx: nil, wantErr: false},
{name: "4.", dialCall: newdialerCallErr, args: args{ctx: context.Background(), p: &msdsn.Config{Instance: "", PacketSize: 32768, LogFlags: msdsn.LogErrors}}, wantDialCtx: nil, wantErr: false},
{name: "5.", dialCall: newdialerCallErr, args: args{ctx: context.Background(), p: &msdsn.Config{Instance: "", PacketSize: 32768, LogFlags: msdsn.LogErrors}}, wantDialCtx: nil, wantErr: false},
{name: "6.", dialCall: newdialerCallSuc, args: args{ctx: context.Background(), p: &msdsn.Config{Instance: "B", PacketSize: 4096, LogFlags: msdsn.LogErrors}}, wantDialCtx: nil, wantErr: false},
{name: "7.", dialCall: newdialerCallErr2, args: args{ctx: context.Background(), p: &msdsn.Config{Instance: "B", PacketSize: 4096, LogFlags: msdsn.LogErrors}}, wantDialCtx: nil, wantErr: true},
{name: "8.", dialCall: newdialerCallErr3, args: args{ctx: context.Background(), p: &msdsn.Config{Instance: "B", PacketSize: 4096, LogFlags: msdsn.LogErrors}}, wantDialCtx: nil, wantErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

c := &Connector{params: *tt.args.p, Dialer: tt.dialCall()}
err := prepareMSDSN(tt.args.ctx, c, driverInstanceNoProcess.logger, tt.args.p)
if (err != nil) != tt.wantErr {
t.Errorf("prepareMSDSN() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}

func Test_parseInstances(t *testing.T) {

tests := []struct {
name string
args []byte
want map[string]map[string]string
}{
{name: "1.len<=3", args: []byte(`abc`), want: map[string]map[string]string{}},
{name: "2.len byte[0]!=5", args: []byte{1, 0, 1, 1}, want: map[string]map[string]string{}},
{name: "3.normal-1", args: append([]byte{5, 0, 0}, []byte(`;b;`)...), want: map[string]map[string]string{}},
{name: "3.normal-2", args: append([]byte{5, 0, 0}, []byte(`InstanceName;b;;`)...), want: map[string]map[string]string{"B": map[string]string{"InstanceName": "b"}}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := parseInstances(tt.args); !reflect.DeepEqual(got, tt.want) {
t.Errorf("parseInstances() = %v, want %v", got, tt.want)
}
})
}
}