Skip to content

Commit

Permalink
assert: Fix EqualValues to handle overflow/underflow
Browse files Browse the repository at this point in the history
The underlying function ObjectsAreEqualValues did not handle
overflow/underflow of values while converting one type to another
for comparison. For example:

    EqualValues(t, int(270), int8(14))

would return true, even though the values are not equal. Because, when
you convert int(270) to int8, it overflows and becomes 14 (270 % 256 = 14)

This commit fixes that by making sure that the conversion always happens
from the smaller type to the larger type, and then comparing the values.
Additionally, this commit also seperates out the test cases of
ObjectsAreEqualValues from TestObjectsAreEqual.

Fixes #1462
  • Loading branch information
arjunmahishi committed Feb 14, 2024
1 parent c3b0c9b commit 4c4d011
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 20 deletions.
35 changes: 29 additions & 6 deletions assert/assertions.go
Expand Up @@ -165,17 +165,40 @@ func ObjectsAreEqualValues(expected, actual interface{}) bool {
return true
}

actualType := reflect.TypeOf(actual)
if actualType == nil {
expectedValue := reflect.ValueOf(expected)
actualValue := reflect.ValueOf(actual)
if !expectedValue.IsValid() || !actualValue.IsValid() {
return false
}
expectedValue := reflect.ValueOf(expected)
if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) {

expectedType := expectedValue.Type()
actualType := actualValue.Type()
if !expectedType.ConvertibleTo(actualType) {
return false
}

if !isNumericType(expectedType) || !isNumericType(actualType) {
// Attempt comparison after type conversion
return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual)
return reflect.DeepEqual(
expectedValue.Convert(actualType).Interface(), actual,
)
}

return false
// If BOTH values are numeric, there are chances of false positives due
// to overflow or underflow. So, we need to make sure to always convert
// the smaller type to a larger type before comparing.
if expectedType.Size() >= actualType.Size() {
return actualValue.Convert(expectedType).Interface() == expected
}

return expectedValue.Convert(actualType).Interface() == actual
}

// isNumericType returns true if the type is one of:
// int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64,
// float32, float64, complex64, complex128
func isNumericType(t reflect.Type) bool {
return t.Kind() >= reflect.Int && t.Kind() <= reflect.Complex128
}

/* CallerInfo is necessary because the assert functions use the testing object
Expand Down
46 changes: 32 additions & 14 deletions assert/assertions_test.go
Expand Up @@ -135,24 +135,42 @@ func TestObjectsAreEqual(t *testing.T) {

})
}
}

// Cases where type differ but values are equal
if !ObjectsAreEqualValues(uint32(10), int32(10)) {
t.Error("ObjectsAreEqualValues should return true")
}
if ObjectsAreEqualValues(0, nil) {
t.Fail()
}
if ObjectsAreEqualValues(nil, 0) {
t.Fail()
}
func TestObjectsAreEqualValues(t *testing.T) {
now := time.Now()

tm := time.Now()
tz := tm.In(time.Local)
if !ObjectsAreEqualValues(tm, tz) {
t.Error("ObjectsAreEqualValues should return true for time.Time objects with different time zones")
cases := []struct {
expected interface{}
actual interface{}
result bool
}{
{uint32(10), int32(10), true},
{0, nil, false},
{nil, 0, false},
{now, now.In(time.Local), true}, // should be time zone independent
{int(270), int8(14), false}, // should handle overflow/underflow
{int8(14), int(270), false},
{[]int{270, 270}, []int8{14, 14}, false},
{complex128(1e+100 + 1e+100i), complex64(complex(math.Inf(0), math.Inf(0))), false},
{complex64(complex(math.Inf(0), math.Inf(0))), complex128(1e+100 + 1e+100i), false},
{complex128(1e+100 + 1e+100i), 270, false},
{270, complex128(1e+100 + 1e+100i), false},
{complex128(1e+100 + 1e+100i), 3.14, false},
{3.14, complex128(1e+100 + 1e+100i), false},
{complex128(1e+10 + 1e+10i), complex64(1e+10 + 1e+10i), true},
{complex64(1e+10 + 1e+10i), complex128(1e+10 + 1e+10i), true},
}

for _, c := range cases {
t.Run(fmt.Sprintf("ObjectsAreEqualValues(%#v, %#v)", c.expected, c.actual), func(t *testing.T) {
res := ObjectsAreEqualValues(c.expected, c.actual)

if res != c.result {
t.Errorf("ObjectsAreEqualValues(%#v, %#v) should return %#v", c.expected, c.actual, c.result)
}
})
}
}

type Nested struct {
Expand Down

0 comments on commit 4c4d011

Please sign in to comment.