Skip to content

Commit

Permalink
Add Auth Spec tests
Browse files Browse the repository at this point in the history
GODRIVER-142

Change-Id: I75dd6c46bbf37913e7e19c63aad1edd58108f75a
  • Loading branch information
rfblue2 authored and Roland Fong committed Jun 12, 2018
1 parent 6ef364f commit bb548a5
Show file tree
Hide file tree
Showing 5 changed files with 576 additions and 19 deletions.
4 changes: 0 additions & 4 deletions core/auth/plain.go
Expand Up @@ -16,10 +16,6 @@ import (
const PLAIN = "PLAIN"

func newPlainAuthenticator(cred *Cred) (Authenticator, error) {
if cred.Source != "" && cred.Source != "$external" {
return nil, newAuthError("PLAIN source must be empty or $external", nil)
}

return &PlainAuthenticator{
Username: cred.Username,
Password: cred.Password,
Expand Down
120 changes: 120 additions & 0 deletions core/connstring/connstring.go
Expand Up @@ -247,6 +247,16 @@ func (p *parser) parse(original string) error {
}
}

err = p.setDefaultAuthParams(extractedDatabase.db)
if err != nil {
return err
}

err = p.validateAuth()
if err != nil {
return err
}

// Check for invalid write concern (i.e. w=0 and j=true)
if p.WNumberSet && p.WNumber == 0 && p.JSet && p.J {
return writeconcern.ErrInconsistent
Expand All @@ -260,6 +270,116 @@ func (p *parser) parse(original string) error {
return nil
}

func (p *parser) setDefaultAuthParams(dbName string) error {
switch strings.ToLower(p.AuthMechanism) {
case "plain":
if p.AuthSource == "" {
p.AuthSource = dbName
if p.AuthSource == "" {
p.AuthSource = "$external"
}
}
case "gssapi":
if p.AuthMechanismProperties == nil {
p.AuthMechanismProperties = map[string]string{
"SERVICE_NAME": "mongodb",
}
} else if v, ok := p.AuthMechanismProperties["SERVICE_NAME"]; !ok || v == "" {
p.AuthMechanismProperties["SERVICE_NAME"] = "mongodb"
}
fallthrough
case "mongodb-x509":
if p.AuthSource == "" {
p.AuthSource = "$external"
} else if p.AuthSource != "$external" {
return fmt.Errorf("auth source must be $external")
}
case "mongodb-cr":
fallthrough
case "scram-sha-1":
fallthrough
case "scram-sha-256":
if p.AuthSource == "" {
p.AuthSource = dbName
if p.AuthSource == "" {
p.AuthSource = "admin"
}
}
case "":
if p.AuthSource == "" {
p.AuthSource = "admin"
}
default:
return fmt.Errorf("invalid auth mechanism")
}
return nil
}

func (p *parser) validateAuth() error {
switch strings.ToLower(p.AuthMechanism) {
case "mongodb-cr":
if p.Username == "" {
return fmt.Errorf("username required for MONGO-CR")
}
if p.Password == "" {
return fmt.Errorf("password required for MONGO-CR")
}
if p.AuthMechanismProperties != nil {
return fmt.Errorf("MONGO-CR cannot have mechanism properties")
}
case "mongodb-x509":
if p.Password != "" {
return fmt.Errorf("password cannot be specified for MONGO-X509")
}
if p.AuthMechanismProperties != nil {
return fmt.Errorf("MONGO-X509 cannot have mechanism properties")
}
case "gssapi":
if p.Username == "" {
return fmt.Errorf("username required for GSSAPI")
}
for k := range p.AuthMechanismProperties {
if k != "SERVICE_NAME" && k != "CANONICALIZE_HOST_NAME" && k != "SERVICE_REALM" {
return fmt.Errorf("invalid auth property for GSSAPI")
}
}
case "plain":
if p.Username == "" {
return fmt.Errorf("username required for PLAIN")
}
if p.Password == "" {
return fmt.Errorf("password required for PLAIN")
}
if p.AuthMechanismProperties != nil {
return fmt.Errorf("PLAIN cannot have mechanism properties")
}
case "scram-sha-1":
if p.Username == "" {
return fmt.Errorf("username required for SCRAM-SHA-1")
}
if p.Password == "" {
return fmt.Errorf("password required for SCRAM-SHA-1")
}
if p.AuthMechanismProperties != nil {
return fmt.Errorf("SCRAM-SHA-1 cannot have mechanism properties")
}
case "scram-sha-256":
if p.Username == "" {
return fmt.Errorf("username required for SCRAM-SHA-256")
}
if p.Password == "" {
return fmt.Errorf("password required for SCRAM-SHA-256")
}
if p.AuthMechanismProperties != nil {
return fmt.Errorf("SCRAM-SHA-256 cannot have mechanism properties")
}
case "":
default:
return fmt.Errorf("invalid auth mechanism")
}
return nil
}

func fetchSeedlistFromSRV(host string) ([]string, error) {
var err error

Expand Down
33 changes: 20 additions & 13 deletions core/connstring/connstring_spec_test.go
Expand Up @@ -42,7 +42,10 @@ type testContainer struct {
Tests []testCase
}

const testsDir string = "../../data/connection-string/"
const connstringTestsDir = "../../data/connection-string/"

// Note a test supporting the deprecated gssapiServiceName property was removed from data/auth/auth_tests.json
const authTestsDir = "../../data/auth/"

func (h *host) toString() string {
switch h.Type {
Expand Down Expand Up @@ -77,8 +80,8 @@ func hostsToStrings(hosts []host) []string {
return out
}

func runTestsInFile(t *testing.T, filename string) {
filepath := path.Join(testsDir, filename)
func runTestsInFile(t *testing.T, dirname string, filename string) {
filepath := path.Join(dirname, filename)
content, err := ioutil.ReadFile(filepath)
require.NoError(t, err)

Expand Down Expand Up @@ -106,7 +109,10 @@ func runTest(t *testing.T, filename string, test *testCase) {
}

require.Equal(t, test.URI, cs.Original)
require.Equal(t, hostsToStrings(test.Hosts), cs.Hosts)

if test.Hosts != nil {
require.Equal(t, hostsToStrings(test.Hosts), cs.Hosts)
}

if test.Auth != nil {
require.Equal(t, test.Auth.Username, cs.Username)
Expand All @@ -118,7 +124,11 @@ func runTest(t *testing.T, filename string, test *testCase) {
require.Equal(t, *test.Auth.Password, cs.Password)
}

require.Equal(t, test.Auth.DB, cs.Database)
if test.Auth.DB != cs.Database {
require.Equal(t, test.Auth.DB, cs.AuthSource)
} else {
require.Equal(t, test.Auth.DB, cs.Database)
}
}

// Check that all options are present.
Expand All @@ -140,14 +150,11 @@ func runTest(t *testing.T, filename string, test *testCase) {

// Test case for all connection string spec tests.
func TestConnStringSpec(t *testing.T) {
entries, err := ioutil.ReadDir(testsDir)
require.NoError(t, err)

for _, entry := range entries {
if entry.IsDir() || path.Ext(entry.Name()) != ".json" {
continue
}
for _, file := range testhelpers.FindJSONFilesInDir(t, connstringTestsDir) {
runTestsInFile(t, connstringTestsDir, file)
}

runTestsInFile(t, entry.Name())
for _, file := range testhelpers.FindJSONFilesInDir(t, authTestsDir) {
runTestsInFile(t, authTestsDir, file)
}
}
4 changes: 2 additions & 2 deletions core/connstring/connstring_test.go
Expand Up @@ -49,11 +49,11 @@ func TestAuthMechanism(t *testing.T) {
}{
{s: "authMechanism=scram-sha-1", expected: "scram-sha-1"},
{s: "authMechanism=mongodb-CR", expected: "mongodb-CR"},
{s: "authMechanism=LDAP", expected: "LDAP"},
{s: "authMechanism=plain", expected: "plain"},
}

for _, test := range tests {
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
s := fmt.Sprintf("mongodb://user:pass@localhost/?%s", test.s)
t.Run(s, func(t *testing.T) {
cs, err := connstring.Parse(s)
if test.err {
Expand Down

0 comments on commit bb548a5

Please sign in to comment.