Skip to content

Commit

Permalink
msdsn package should support azuresql:// driver name
Browse files Browse the repository at this point in the history
  • Loading branch information
dss-vipps committed Aug 25, 2022
1 parent 1598eaf commit f49b1da
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 11 deletions.
12 changes: 9 additions & 3 deletions msdsn/conn_str.go
Expand Up @@ -38,6 +38,7 @@ const (
)

type Config struct {
Scheme string
Port uint64
Host string
Instance string
Expand Down Expand Up @@ -117,7 +118,7 @@ func Parse(dsn string) (Config, map[string]string, error) {
return p, params, err
}
params = parameters
} else if strings.HasPrefix(dsn, "sqlserver://") {
} else if strings.HasPrefix(dsn, "sqlserver://") || strings.HasPrefix(dsn, "azuresql://") {
parameters, err := splitConnectionStringURL(dsn)
if err != nil {
return p, params, err
Expand All @@ -127,6 +128,11 @@ func Parse(dsn string) (Config, map[string]string, error) {
params = splitConnectionString(dsn)
}

p.Scheme = "sqlserver"
if strings.HasPrefix(dsn, "azuresql://") {
p.Scheme = "azuresql"
}

strlog, ok := params["log"]
if ok {
flags, err := strconv.ParseUint(strlog, 10, 64)
Expand Down Expand Up @@ -342,7 +348,7 @@ func (p Config) URL() *url.URL {
}
q.Add("disableRetry", fmt.Sprintf("%t", p.DisableRetry))
res := url.URL{
Scheme: "sqlserver",
Scheme: p.Scheme,
Host: host,
User: url.UserPassword(p.User, p.Password),
}
Expand Down Expand Up @@ -410,7 +416,7 @@ func splitConnectionStringURL(dsn string) (map[string]string, error) {
return res, err
}

if u.Scheme != "sqlserver" {
if u.Scheme != "sqlserver" && u.Scheme != "azuresql" {
return res, fmt.Errorf("scheme %s is not recognized", u.Scheme)
}

Expand Down
51 changes: 43 additions & 8 deletions msdsn/conn_str_test.go
Expand Up @@ -136,28 +136,28 @@ func TestValidConnectionString(t *testing.T) {

// URL mode
{"sqlserver://somehost?connection+timeout=30", func(p Config) bool {
return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.ConnTimeout == 30*time.Second
return p.Scheme == "sqlserver" && p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.ConnTimeout == 30*time.Second
}},
{"sqlserver://someuser@somehost?connection+timeout=30", func(p Config) bool {
return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second
return p.Scheme == "sqlserver" && p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second
}},
{"sqlserver://someuser:@somehost?connection+timeout=30", func(p Config) bool {
return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second
return p.Scheme == "sqlserver" && p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second
}},
{"sqlserver://someuser:foo%3A%2F%5C%21~%40;bar@somehost?connection+timeout=30", func(p Config) bool {
return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "foo:/\\!~@;bar" && p.ConnTimeout == 30*time.Second
return p.Scheme == "sqlserver" && p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "foo:/\\!~@;bar" && p.ConnTimeout == 30*time.Second
}},
{"sqlserver://someuser:foo%3A%2F%5C%21~%40;bar@somehost:1434?connection+timeout=30", func(p Config) bool {
return p.Host == "somehost" && p.Port == 1434 && p.Instance == "" && p.User == "someuser" && p.Password == "foo:/\\!~@;bar" && p.ConnTimeout == 30*time.Second
return p.Scheme == "sqlserver" && p.Host == "somehost" && p.Port == 1434 && p.Instance == "" && p.User == "someuser" && p.Password == "foo:/\\!~@;bar" && p.ConnTimeout == 30*time.Second
}},
{"sqlserver://someuser:foo%3A%2F%5C%21~%40;bar@somehost:1434/someinstance?connection+timeout=30", func(p Config) bool {
return p.Host == "somehost" && p.Port == 1434 && p.Instance == "someinstance" && p.User == "someuser" && p.Password == "foo:/\\!~@;bar" && p.ConnTimeout == 30*time.Second
return p.Scheme == "sqlserver" && p.Host == "somehost" && p.Port == 1434 && p.Instance == "someinstance" && p.User == "someuser" && p.Password == "foo:/\\!~@;bar" && p.ConnTimeout == 30*time.Second
}},
{"sqlserver://someuser@somehost?disableretry=true", func(p Config) bool {
return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.DisableRetry
return p.Scheme == "sqlserver" && p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.DisableRetry
}},
{"sqlserver://someuser@somehost?connection+timeout=30&disableretry=1", func(p Config) bool {
return p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second && p.DisableRetry
return p.Scheme == "sqlserver" && p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.User == "someuser" && p.Password == "" && p.ConnTimeout == 30*time.Second && p.DisableRetry
}},
}
for _, ts := range connStrings {
Expand All @@ -175,6 +175,41 @@ func TestValidConnectionString(t *testing.T) {
}
}

func TestValidConnectionStringWithParams(t *testing.T) {
// Like the test above, but strings where we want to have assertions on params being passed through;
// would be very verbose to change the test above
type testStruct struct {
connStr string
check func(Config, map[string]string) bool
}
connStrings := []testStruct{
{"azuresql://somehost?connection+timeout=30&fedauth=ActiveDirectoryDefault", func(p Config, params map[string]string) bool {
// fedauth is in params, see TestParams
if !(p.Scheme == "azuresql" && p.Host == "somehost" && p.Port == 0 && p.Instance == "" && p.Password == "") {
return false
}
fedauth, ok := params["fedauth"]
if !ok {
return false
}
return fedauth == "ActiveDirectoryDefault"
}},
}
for _, ts := range connStrings {
p, params, err := Parse(ts.connStr)
if err == nil {
t.Logf("Connection string was parsed successfully %s", ts.connStr)
} else {
t.Errorf("Connection string %s failed to parse with error %s", ts.connStr, err)
continue
}

if !ts.check(p, params) {
t.Errorf("Check failed on conn str %s", ts.connStr)
}
}
}

func TestSplitConnectionStringURL(t *testing.T) {
_, err := splitConnectionStringURL("http://bad")
if err == nil {
Expand Down

0 comments on commit f49b1da

Please sign in to comment.