Skip to content

Commit

Permalink
Add support for AD federated authentication login sequence (unexported)
Browse files Browse the repository at this point in the history
  • Loading branch information
wrosenuance authored and kardianos committed Dec 18, 2020
1 parent 095ece7 commit 045585d
Show file tree
Hide file tree
Showing 12 changed files with 984 additions and 219 deletions.
6 changes: 6 additions & 0 deletions .gitignore
@@ -1,2 +1,8 @@
/.idea
/.connstr
.vscode
.terraform
*.tfstate*
*.log
*.swp
*~
29 changes: 4 additions & 25 deletions accesstokenconnector.go
Expand Up @@ -6,19 +6,8 @@ import (
"context"
"database/sql/driver"
"errors"
"fmt"
)

var _ driver.Connector = &accessTokenConnector{}

// accessTokenConnector wraps Connector and injects a
// fresh access token when connecting to the database
type accessTokenConnector struct {
Connector

accessTokenProvider func() (string, error)
}

// NewAccessTokenConnector creates a new connector from a DSN and a token provider.
// The token provider func will be called when a new connection is requested and should return a valid access token.
// The returned connector may be used with sql.OpenDB.
Expand All @@ -32,20 +21,10 @@ func NewAccessTokenConnector(dsn string, tokenProvider func() (string, error)) (
return nil, err
}

c := &accessTokenConnector{
Connector: *conn,
accessTokenProvider: tokenProvider,
}
return c, nil
}

// Connect returns a new database connection
func (c *accessTokenConnector) Connect(ctx context.Context) (driver.Conn, error) {
var err error
c.Connector.params.fedAuthAccessToken, err = c.accessTokenProvider()
if err != nil {
return nil, fmt.Errorf("mssql: error retrieving access token: %+v", err)
conn.params.fedAuthLibrary = fedAuthLibrarySecurityToken
conn.securityTokenProvider = func(ctx context.Context) (string, error) {
return tokenProvider()
}

return c.Connector.Connect(ctx)
return conn, nil
}
12 changes: 6 additions & 6 deletions accesstokenconnector_test.go
Expand Up @@ -30,21 +30,21 @@ func TestNewAccessTokenConnector(t *testing.T) {
dsn: dsn,
tokenProvider: tp},
want: func(c driver.Connector) error {
tc, ok := c.(*accessTokenConnector)
tc, ok := c.(*Connector)
if !ok {
return fmt.Errorf("Expected driver to be of type *accessTokenConnector, but got %T", c)
return fmt.Errorf("Expected driver to be of type *Connector, but got %T", c)
}
p := tc.Connector.params
p := tc.params
if p.database != "db" {
return fmt.Errorf("expected params.database=db, but got %v", p.database)
}
if p.host != "server.database.windows.net" {
return fmt.Errorf("expected params.host=server.database.windows.net, but got %v", p.host)
}
if tc.accessTokenProvider == nil {
return fmt.Errorf("Expected tokenProvider to not be nil")
if tc.securityTokenProvider == nil {
return fmt.Errorf("Expected federated authentication provider to not be nil")
}
t, err := tc.accessTokenProvider()
t, err := tc.securityTokenProvider(context.TODO())
if t != "token" || err != nil {
return fmt.Errorf("Unexpected results from tokenProvider: %v, %v", t, err)
}
Expand Down
11 changes: 7 additions & 4 deletions conn_str.go
Expand Up @@ -37,14 +37,17 @@ type connectParams struct {
failOverPartner string
failOverPort uint64
packetSize uint16
fedAuthAccessToken string
fedAuthLibrary int
fedAuthADALWorkflow byte
}

// default packet size for TDS buffer
const defaultPacketSize = 4096

func parseConnectParams(dsn string) (connectParams, error) {
var p connectParams
p := connectParams{
fedAuthLibrary: fedAuthLibraryReserved,
}

var params map[string]string
if strings.HasPrefix(dsn, "odbc:") {
Expand Down Expand Up @@ -247,8 +250,8 @@ func (p connectParams) toUrl() *url.URL {
}
res := url.URL{
Scheme: "sqlserver",
Host: p.host,
User: url.UserPassword(p.user, p.password),
Host: p.host,
User: url.UserPassword(p.user, p.password),
}
if p.instance != "" {
res.Path = p.instance
Expand Down
82 changes: 82 additions & 0 deletions fedauth.go
@@ -0,0 +1,82 @@
package mssql

import (
"context"
"errors"
)

// Federated authentication library affects the login data structure and message sequence.
const (
// fedAuthLibraryLiveIDCompactToken specifies the Microsoft Live ID Compact Token authentication scheme
fedAuthLibraryLiveIDCompactToken = 0x00

// fedAuthLibrarySecurityToken specifies a token-based authentication where the token is available
// without additional information provided during the login sequence.
fedAuthLibrarySecurityToken = 0x01

// fedAuthLibraryADAL specifies a token-based authentication where a token is obtained during the
// login sequence using the server SPN and STS URL provided by the server during login.
fedAuthLibraryADAL = 0x02

// fedAuthLibraryReserved is used to indicate that no federated authentication scheme applies.
fedAuthLibraryReserved = 0x7F
)

// Federated authentication ADAL workflow affects the mechanism used to authenticate.
const (
// fedAuthADALWorkflowPassword uses a username/password to obtain a token from Active Directory
fedAuthADALWorkflowPassword = 0x01

// fedAuthADALWorkflowPassword uses the Windows identity to obtain a token from Active Directory
fedAuthADALWorkflowIntegrated = 0x02

// fedAuthADALWorkflowMSI uses the managed identity service to obtain a token
fedAuthADALWorkflowMSI = 0x03
)

// newSecurityTokenConnector creates a new connector from a DSN and a token provider.
// When invoked, token provider implementations should contact the security token
// service specified and obtain the appropriate token, or return an error
// to indicate why a token is not available.
// The returned connector may be used with sql.OpenDB.
func newSecurityTokenConnector(dsn string, tokenProvider func(ctx context.Context) (string, error)) (*Connector, error) {
if tokenProvider == nil {
return nil, errors.New("mssql: tokenProvider cannot be nil")
}

conn, err := NewConnector(dsn)
if err != nil {
return nil, err
}

conn.params.fedAuthLibrary = fedAuthLibrarySecurityToken
conn.securityTokenProvider = tokenProvider

return conn, nil
}

// newADALTokenConnector creates a new connector from a DSN and a Active Directory token provider.
// Token provider implementations are called during federated
// authentication login sequences where the server provides a service
// principal name and security token service endpoint that should be used
// to obtain the token. Implementations should contact the security token
// service specified and obtain the appropriate token, or return an error
// to indicate why a token is not available.
//
// The returned connector may be used with sql.OpenDB.
func newActiveDirectoryTokenConnector(dsn string, adalWorkflow byte, tokenProvider func(ctx context.Context, serverSPN, stsURL string) (string, error)) (*Connector, error) {
if tokenProvider == nil {
return nil, errors.New("mssql: tokenProvider cannot be nil")
}

conn, err := NewConnector(dsn)
if err != nil {
return nil, err
}

conn.params.fedAuthLibrary = fedAuthLibraryADAL
conn.params.fedAuthADALWorkflow = adalWorkflow
conn.adalTokenProvider = tokenProvider

return conn, nil
}
17 changes: 12 additions & 5 deletions mssql.go
Expand Up @@ -58,6 +58,7 @@ func (d *Driver) OpenConnector(dsn string) (*Connector, error) {
if err != nil {
return nil, err
}

return &Connector{
params: params,
driver: d,
Expand Down Expand Up @@ -100,6 +101,12 @@ type Connector struct {
params connectParams
driver *Driver

// callback that can provide a security token during login
securityTokenProvider func(ctx context.Context) (string, error)

// callback that can provide a security token during ADAL login
adalTokenProvider func(ctx context.Context, serverSPN, stsURL string) (string, error)

// SessionInitSQL is executed after marking a given session to be reset.
// When not present, the next query will still reset the session to the
// database defaults.
Expand Down Expand Up @@ -148,7 +155,7 @@ type Conn struct {
processQueryText bool
connectionGood bool

outs map[string]interface{}
outs map[string]interface{}
}

func (c *Conn) checkBadConn(err error) error {
Expand Down Expand Up @@ -653,9 +660,9 @@ func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) {
}

type Rows struct {
stmt *Stmt
cols []columnStruct
reader *tokenProcessor
stmt *Stmt
cols []columnStruct
reader *tokenProcessor
nextCols []columnStruct

cancel func()
Expand All @@ -669,7 +676,7 @@ func (rc *Rows) Close() error {
for {
tok, err := rc.reader.nextToken()
if err == nil {
if tok == nil {
if tok == nil {
return nil
} else {
// continue consuming tokens
Expand Down

0 comments on commit 045585d

Please sign in to comment.