From 44a150f3593441156b2d568bcad45f66e8f7339c Mon Sep 17 00:00:00 2001 From: hzw Date: Thu, 3 Sep 2020 20:27:05 +0800 Subject: [PATCH] add required_if and required_unless --- README.md | 2 + baked_in.go | 59 +++++++++++++++++ doc.go | 34 ++++++++++ validator_instance.go | 4 +- validator_test.go | 145 ++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 243 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index f5528266..f71bbafd 100644 --- a/README.md +++ b/README.md @@ -197,6 +197,8 @@ Baked-in Validations | min | Minimum | | oneof | One Of | | required | Required | +| required_if | Required If | +| required_unless | Required Unless | | required_with | Required With | | required_with_all | Required With All | | required_without | Required Without | diff --git a/baked_in.go b/baked_in.go index 36e80572..8d80b7a3 100644 --- a/baked_in.go +++ b/baked_in.go @@ -64,6 +64,8 @@ var ( // or even disregard and use your own map if so desired. bakedInValidators = map[string]Func{ "required": hasValue, + "required_if": requiredIf, + "required_unless": requiredUnless, "required_with": requiredWith, "required_with_all": requiredWithAll, "required_without": requiredWithout, @@ -1383,6 +1385,63 @@ func requireCheckFieldKind(fl FieldLevel, param string, defaultNotFoundValue boo } } +// requireCheckFieldValue is a func for check field value +func requireCheckFieldValue(fl FieldLevel, param string, value string, defaultNotFoundValue bool) bool { + field, kind, _, found := fl.GetStructFieldOKAdvanced2(fl.Parent(), param) + if !found { + return defaultNotFoundValue + } + + switch kind { + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return field.Int() == asInt(value) + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return field.Uint() == asUint(value) + + case reflect.Float32, reflect.Float64: + return field.Float() == asFloat(value) + + case reflect.Slice, reflect.Map, reflect.Array: + return int64(field.Len()) == asInt(value) + } + + // default reflect.String: + return field.String() == value +} + +// requiredIf is the validation function +// The field under validation must be present and not empty only if all the other specified fields are equal to the value following with the specified field. +func requiredIf(fl FieldLevel) bool { + params := parseOneOfParam2(fl.Param()) + if len(params)%2 != 0 { + panic(fmt.Sprintf("Bad param number for required_if %s", fl.FieldName())) + } + for i := 0; i < len(params); i += 2 { + if !requireCheckFieldValue(fl, params[i], params[i+1], false) { + return true + } + } + return hasValue(fl) +} + +// requiredUnless is the validation function +// The field under validation must be present and not empty only unless all the other specified fields are equal to the value following with the specified field. +func requiredUnless(fl FieldLevel) bool { + params := parseOneOfParam2(fl.Param()) + if len(params)%2 != 0 { + panic(fmt.Sprintf("Bad param number for required_unless %s", fl.FieldName())) + } + + for i := 0; i < len(params); i += 2 { + if requireCheckFieldValue(fl, params[i], params[i+1], false) { + return true + } + } + return hasValue(fl) +} + // RequiredWith is the validation function // The field under validation must be present and not empty only if any of the other specified fields are present. func requiredWith(fl FieldLevel) bool { diff --git a/doc.go b/doc.go index 4aba75f0..85527e99 100644 --- a/doc.go +++ b/doc.go @@ -245,6 +245,40 @@ ensures the value is not nil. Usage: required +Required If + +The field under validation must be present and not empty only if all +the other specified fields are equal to the value following the specified +field. For strings ensures value is not "". For slices, maps, pointers, +interfaces, channels and functions ensures the value is not nil. + + Usage: required_if + +Examples: + + // require the field if the Field1 is equal to the parameter given: + Usage: required_if=Field1 foobar + + // require the field if the Field1 and Field2 is equal to the value respectively: + Usage: required_if=Field1 foo Field2 bar + +Required Unless + +The field under validation must be present and not empty unless all +the other specified fields are equal to the value following the specified +field. For strings ensures value is not "". For slices, maps, pointers, +interfaces, channels and functions ensures the value is not nil. + + Usage: required_unless + +Examples: + + // require the field unless the Field1 is equal to the parameter given: + Usage: required_unless=Field1 foobar + + // require the field unless the Field1 and Field2 is equal to the value respectively: + Usage: required_unless=Field1 foo Field2 bar + Required With The field under validation must be present and not empty only if any diff --git a/validator_instance.go b/validator_instance.go index 4a89d406..74acec03 100644 --- a/validator_instance.go +++ b/validator_instance.go @@ -27,6 +27,8 @@ const ( requiredWithoutTag = "required_without" requiredWithTag = "required_with" requiredWithAllTag = "required_with_all" + requiredIfTag = "required_if" + requiredUnlessTag = "required_unless" skipValidationTag = "-" diveTag = "dive" keysTag = "keys" @@ -107,7 +109,7 @@ func New() *Validate { switch k { // these require that even if the value is nil that the validation should run, omitempty still overrides this behaviour - case requiredWithTag, requiredWithAllTag, requiredWithoutTag, requiredWithoutAllTag: + case requiredIfTag, requiredUnlessTag, requiredWithTag, requiredWithAllTag, requiredWithoutTag, requiredWithoutAllTag: _ = v.registerValidation(k, wrapFunc(val), true, true) default: // no need to error check here, baked in will always be valid diff --git a/validator_test.go b/validator_test.go index e76a3cd9..e5e6edca 100644 --- a/validator_test.go +++ b/validator_test.go @@ -8695,6 +8695,151 @@ func TestEndsWithValidation(t *testing.T) { } } +func TestRequiredIf(t *testing.T) { + type Inner struct { + Field *string + } + + fieldVal := "test" + test := struct { + Inner *Inner + FieldE string `validate:"omitempty" json:"field_e"` + FieldER string `validate:"required_if=FieldE test" json:"field_er"` + Field1 string `validate:"omitempty" json:"field_1"` + Field2 *string `validate:"required_if=Field1 test" json:"field_2"` + Field3 map[string]string `validate:"required_if=Field2 test" json:"field_3"` + Field4 interface{} `validate:"required_if=Field3 1" json:"field_4"` + Field5 int `validate:"required_if=Inner.Field test" json:"field_5"` + Field6 uint `validate:"required_if=Field5 1" json:"field_6"` + Field7 float32 `validate:"required_if=Field6 1" json:"field_7"` + Field8 float64 `validate:"required_if=Field7 1.0" json:"field_8"` + }{ + Inner: &Inner{Field: &fieldVal}, + Field2: &fieldVal, + Field3: map[string]string{"key": "val"}, + Field4: "test", + Field5: 2, + } + + validate := New() + + errs := validate.Struct(test) + Equal(t, errs, nil) + + test2 := struct { + Inner *Inner + Inner2 *Inner + FieldE string `validate:"omitempty" json:"field_e"` + FieldER string `validate:"required_if=FieldE test" json:"field_er"` + Field1 string `validate:"omitempty" json:"field_1"` + Field2 *string `validate:"required_if=Field1 test" json:"field_2"` + Field3 map[string]string `validate:"required_if=Field2 test" json:"field_3"` + Field4 interface{} `validate:"required_if=Field2 test" json:"field_4"` + Field5 string `validate:"required_if=Field3 1" json:"field_5"` + Field6 string `validate:"required_if=Inner.Field test" json:"field_6"` + Field7 string `validate:"required_if=Inner2.Field test" json:"field_7"` + }{ + Inner: &Inner{Field: &fieldVal}, + Field2: &fieldVal, + } + + errs = validate.Struct(test2) + NotEqual(t, errs, nil) + + ve := errs.(ValidationErrors) + Equal(t, len(ve), 3) + AssertError(t, errs, "Field3", "Field3", "Field3", "Field3", "required_if") + AssertError(t, errs, "Field4", "Field4", "Field4", "Field4", "required_if") + AssertError(t, errs, "Field6", "Field6", "Field6", "Field6", "required_if") + + defer func() { + if r := recover(); r == nil { + t.Errorf("test3 should have panicked!") + } + }() + + test3 := struct { + Inner *Inner + Field1 string `validate:"required_if=Inner.Field" json:"field_1"` + }{ + Inner: &Inner{Field: &fieldVal}, + } + _ = validate.Struct(test3) +} + +func TestRequiredUnless(t *testing.T) { + type Inner struct { + Field *string + } + + fieldVal := "test" + test := struct { + Inner *Inner + FieldE string `validate:"omitempty" json:"field_e"` + FieldER string `validate:"required_unless=FieldE test" json:"field_er"` + Field1 string `validate:"omitempty" json:"field_1"` + Field2 *string `validate:"required_unless=Field1 test" json:"field_2"` + Field3 map[string]string `validate:"required_unless=Field2 test" json:"field_3"` + Field4 interface{} `validate:"required_unless=Field3 1" json:"field_4"` + Field5 int `validate:"required_unless=Inner.Field test" json:"field_5"` + Field6 uint `validate:"required_unless=Field5 2" json:"field_6"` + Field7 float32 `validate:"required_unless=Field6 0" json:"field_7"` + Field8 float64 `validate:"required_unless=Field7 0.0" json:"field_8"` + }{ + FieldE: "test", + Field2: &fieldVal, + Field3: map[string]string{"key": "val"}, + Field4: "test", + Field5: 2, + } + + validate := New() + + errs := validate.Struct(test) + Equal(t, errs, nil) + + test2 := struct { + Inner *Inner + Inner2 *Inner + FieldE string `validate:"omitempty" json:"field_e"` + FieldER string `validate:"required_unless=FieldE test" json:"field_er"` + Field1 string `validate:"omitempty" json:"field_1"` + Field2 *string `validate:"required_unless=Field1 test" json:"field_2"` + Field3 map[string]string `validate:"required_unless=Field2 test" json:"field_3"` + Field4 interface{} `validate:"required_unless=Field2 test" json:"field_4"` + Field5 string `validate:"required_unless=Field3 0" json:"field_5"` + Field6 string `validate:"required_unless=Inner.Field test" json:"field_6"` + Field7 string `validate:"required_unless=Inner2.Field test" json:"field_7"` + }{ + Inner: &Inner{Field: &fieldVal}, + FieldE: "test", + Field1: "test", + } + + errs = validate.Struct(test2) + NotEqual(t, errs, nil) + + ve := errs.(ValidationErrors) + Equal(t, len(ve), 3) + AssertError(t, errs, "Field3", "Field3", "Field3", "Field3", "required_unless") + AssertError(t, errs, "Field4", "Field4", "Field4", "Field4", "required_unless") + AssertError(t, errs, "Field7", "Field7", "Field7", "Field7", "required_unless") + + defer func() { + if r := recover(); r == nil { + t.Errorf("test3 should have panicked!") + } + }() + + test3 := struct { + Inner *Inner + Field1 string `validate:"required_unless=Inner.Field" json:"field_1"` + }{ + Inner: &Inner{Field: &fieldVal}, + } + _ = validate.Struct(test3) +} + func TestRequiredWith(t *testing.T) { type Inner struct { Field *string