Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

passing context to the tokenprovider func #597

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions accesstokenconnector.go
Expand Up @@ -16,13 +16,13 @@ var _ driver.Connector = &accessTokenConnector{}
type accessTokenConnector struct {
Connector

accessTokenProvider func() (string, error)
accessTokenProvider func(ctx context.Context) (string, error)
}

// NewAccessTokenConnector creates a new connector from a DSN and a token provider.
// The token provider func will be called when a new connection is requested and should return a valid access token.
// The returned connector may be used with sql.OpenDB.
func NewAccessTokenConnector(dsn string, tokenProvider func() (string, error)) (driver.Connector, error) {
func NewAccessTokenConnector(dsn string, tokenProvider func(ctx context.Context) (string, error)) (driver.Connector, error) {
if tokenProvider == nil {
return nil, errors.New("mssql: tokenProvider cannot be nil")
}
Expand All @@ -42,7 +42,7 @@ func NewAccessTokenConnector(dsn string, tokenProvider func() (string, error)) (
// Connect returns a new database connection
func (c *accessTokenConnector) Connect(ctx context.Context) (driver.Conn, error) {
var err error
c.Connector.params.fedAuthAccessToken, err = c.accessTokenProvider()
c.Connector.params.fedAuthAccessToken, err = c.accessTokenProvider(ctx)
if err != nil {
return nil, fmt.Errorf("mssql: error retrieving access token: %+v", err)
}
Expand Down
8 changes: 4 additions & 4 deletions accesstokenconnector_test.go
Expand Up @@ -13,10 +13,10 @@ import (

func TestNewAccessTokenConnector(t *testing.T) {
dsn := "Server=server.database.windows.net;Database=db"
tp := func() (string, error) { return "token", nil }
tp := func(ctx context.Context) (string, error) { return "token", nil }
type args struct {
dsn string
tokenProvider func() (string, error)
tokenProvider func(ctx context.Context) (string, error)
}
tests := []struct {
name string
Expand Down Expand Up @@ -44,7 +44,7 @@ func TestNewAccessTokenConnector(t *testing.T) {
if tc.accessTokenProvider == nil {
return fmt.Errorf("Expected tokenProvider to not be nil")
}
t, err := tc.accessTokenProvider()
t, err := tc.accessTokenProvider(context.TODO())
if t != "token" || err != nil {
return fmt.Errorf("Unexpected results from tokenProvider: %v, %v", t, err)
}
Expand Down Expand Up @@ -80,7 +80,7 @@ func TestNewAccessTokenConnector(t *testing.T) {
func TestAccessTokenConnectorFailsToConnectIfNoAccessToken(t *testing.T) {
errorText := "This is a test"
dsn := "Server=server.database.windows.net;Database=db"
tp := func() (string, error) { return "", errors.New(errorText) }
tp := func(ctx context.Context) (string, error) { return "", errors.New(errorText) }
sut, err := NewAccessTokenConnector(dsn, tp)
if err != nil {
t.Fatalf("expected err==nil, but got %+v", err)
Expand Down
2 changes: 1 addition & 1 deletion examples/azuread-accesstoken/go.mod
Expand Up @@ -4,5 +4,5 @@ go 1.13

require (
github.com/Azure/go-autorest/autorest/adal v0.8.1
github.com/denisenkom/go-mssqldb v0.0.0-20191128021309-1d7a30a10f73
github.com/denisenkom/go-mssqldb v0.0.0-20200831201914-36b6ff1bbc10
)
8 changes: 4 additions & 4 deletions examples/azuread-accesstoken/go.sum
@@ -1,7 +1,5 @@
github.com/Azure/go-autorest v13.3.2+incompatible h1:VxzPyuhtnlBOzc4IWCZHqpyH2d+QMLQEuy3wREyY4oc=
github.com/Azure/go-autorest/autorest v0.9.0 h1:MRvx8gncNaXJqOoLmhNjUAKh33JJF8LyxPhomEtOsjs=
github.com/Azure/go-autorest/autorest v0.9.0/go.mod h1:xyHB1BMZT0cuDHU7I0+g046+BFDTQ8rEZB0s4Yfa6bI=
github.com/Azure/go-autorest/autorest v0.9.4 h1:1cM+NmKw91+8h5vfjgzK4ZGLuN72k87XVZBWyGwNjUM=
github.com/Azure/go-autorest/autorest/adal v0.5.0/go.mod h1:8Z9fGy2MpX0PvDjB1pEgQTmVqjGhiHBW7RJJEciWzS0=
github.com/Azure/go-autorest/autorest/adal v0.8.1 h1:pZdL8o72rK+avFWl+p9nE8RWi1JInZrWJYlnpfXJwHk=
github.com/Azure/go-autorest/autorest/adal v0.8.1/go.mod h1:ZjhuQClTqx435SRJ2iMlOxPYt3d2C/T/7TiQCVZSn3Q=
Expand All @@ -10,12 +8,14 @@ github.com/Azure/go-autorest/autorest/date v0.2.0 h1:yW+Zlqf26583pE43KhfnhFcdmSW
github.com/Azure/go-autorest/autorest/date v0.2.0/go.mod h1:vcORJHLJEh643/Ioh9+vPmf1Ij9AEBM5FuBIXLmIy0g=
github.com/Azure/go-autorest/autorest/mocks v0.1.0/go.mod h1:OTyCOPRA2IgIlWxVYxBee2F5Gr4kF2zd2J5cFRaIDN0=
github.com/Azure/go-autorest/autorest/mocks v0.2.0/go.mod h1:OTyCOPRA2IgIlWxVYxBee2F5Gr4kF2zd2J5cFRaIDN0=
github.com/Azure/go-autorest/autorest/mocks v0.3.0 h1:qJumjCaCudz+OcqE9/XtEPfvtOjOmKaui4EOpFI6zZc=
github.com/Azure/go-autorest/autorest/mocks v0.3.0/go.mod h1:a8FDP3DYzQ4RYfVAxAN3SVSiiO77gL2j2ronKKP0syM=
github.com/Azure/go-autorest/logger v0.1.0 h1:ruG4BSDXONFRrZZJ2GUXDiUyVpayPmb1GnWeHDdaNKY=
github.com/Azure/go-autorest/logger v0.1.0/go.mod h1:oExouG+K6PryycPJfVSxi/koC6LSNgds39diKLz7Vrc=
github.com/Azure/go-autorest/tracing v0.5.0 h1:TRn4WjSnkcSy5AEG3pnbtFSwNtwzjr4VYyQflFE619k=
github.com/Azure/go-autorest/tracing v0.5.0/go.mod h1:r/s2XiOKccPW3HrqB+W0TQzfbtp2fGCgRFtBroKn4Dk=
github.com/denisenkom/go-mssqldb v0.0.0-20191128021309-1d7a30a10f73 h1:OGNva6WhsKst5OZf7eZOklDztV3hwtTHovdrLHV+MsA=
github.com/denisenkom/go-mssqldb v0.0.0-20191128021309-1d7a30a10f73/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
github.com/denisenkom/go-mssqldb v0.0.0-20200831201914-36b6ff1bbc10 h1:uuDqxM2PbeYyXcKIo/IP0ZLGDzougMipEBBrCOzr50w=
github.com/denisenkom/go-mssqldb v0.0.0-20200831201914-36b6ff1bbc10/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY=
Expand Down
7 changes: 4 additions & 3 deletions examples/azuread-accesstoken/managed_identity.go
@@ -1,6 +1,7 @@
package main

import (
"context"
"database/sql"
"flag"
"fmt"
Expand Down Expand Up @@ -63,7 +64,7 @@ func main() {
fmt.Printf("bye\n")
}

func getMSITokenProvider() (func() (string, error), error) {
func getMSITokenProvider() (func(ctx context.Context) (string, error), error) {
msiEndpoint, err := adal.GetMSIEndpoint()
if err != nil {
return nil, err
Expand All @@ -74,8 +75,8 @@ func getMSITokenProvider() (func() (string, error), error) {
return nil, err
}

return func() (string, error) {
msi.EnsureFresh()
return func(ctx context.Context) (string, error) {
msi.EnsureFreshWithContext(ctx)
token := msi.OAuthToken()
return token, nil
}, nil
Expand Down