Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow lossless integer to float conversions #325

Merged
merged 2 commits into from
Nov 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
33 changes: 30 additions & 3 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ type Primitive struct {
context Key
}

// The significand precision for float32 and float64 is 24 and 53 bits; this is
// the range a natural number can be stored in a float without loss of data.
const (
maxSafeFloat32Int = 16777215 // 2^24-1
maxSafeFloat64Int = 9007199254740991 // 2^53-1
)

// PrimitiveDecode is just like the other `Decode*` functions, except it
// decodes a TOML value that has already been parsed. Valid primitive values
// can *only* be obtained from values filled by the decoder functions,
Expand Down Expand Up @@ -217,9 +224,7 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
return e("unsupported type %s", rv.Type())
}
return md.unifyAnything(data, rv)
case reflect.Float32:
fallthrough
case reflect.Float64:
case reflect.Float32, reflect.Float64:
return md.unifyFloat64(data, rv)
}
return e("unsupported type %s", rv.Kind())
Expand Down Expand Up @@ -357,6 +362,9 @@ func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
if num, ok := data.(float64); ok {
switch rv.Kind() {
case reflect.Float32:
if num < -math.MaxFloat32 || num > math.MaxFloat32 {
return e("value %f is out of range for float32", num)
}
fallthrough
case reflect.Float64:
rv.SetFloat(num)
Expand All @@ -365,6 +373,25 @@ func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
}
return nil
}

if num, ok := data.(int64); ok {
switch rv.Kind() {
case reflect.Float32:
if num < -maxSafeFloat32Int || num > maxSafeFloat32Int {
return e("value %d is out of range for float32", num)
}
fallthrough
case reflect.Float64:
if num < -maxSafeFloat64Int || num > maxSafeFloat64Int {
return e("value %d is out of range for float64", num)
}
rv.SetFloat(float64(num))
default:
panic("bug")
}
return nil
}

return badtype("float", data)
}

Expand Down
45 changes: 45 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package toml
import (
"fmt"
"io/ioutil"
"math"
"os"
"reflect"
"strings"
Expand Down Expand Up @@ -266,6 +267,50 @@ func TestDecodeIntOverflow(t *testing.T) {
}
}

func TestDecodeFloatOverflow(t *testing.T) {
tests := []struct {
value string
overflow bool
}{
{fmt.Sprintf(`F32 = %f`, math.MaxFloat64), true},
{fmt.Sprintf(`F32 = %f`, -math.MaxFloat64), true},
{fmt.Sprintf(`F32 = %f`, math.MaxFloat32*1.1), true},
{fmt.Sprintf(`F32 = %f`, -math.MaxFloat32*1.1), true},
{fmt.Sprintf(`F32 = %d`, maxSafeFloat32Int+1), true},
{fmt.Sprintf(`F32 = %d`, -maxSafeFloat32Int-1), true},
{fmt.Sprintf(`F64 = %d`, maxSafeFloat64Int+1), true},
{fmt.Sprintf(`F64 = %d`, -maxSafeFloat64Int-1), true},

{fmt.Sprintf(`F32 = %f`, math.MaxFloat32), false},
{fmt.Sprintf(`F32 = %f`, -math.MaxFloat32), false},
{fmt.Sprintf(`F32 = %d`, maxSafeFloat32Int), false},
{fmt.Sprintf(`F32 = %d`, -maxSafeFloat32Int), false},
{fmt.Sprintf(`F64 = %f`, math.MaxFloat64), false},
{fmt.Sprintf(`F64 = %f`, -math.MaxFloat64), false},
{fmt.Sprintf(`F64 = %f`, math.MaxFloat32), false},
{fmt.Sprintf(`F64 = %f`, -math.MaxFloat32), false},
{fmt.Sprintf(`F64 = %d`, maxSafeFloat64Int), false},
{fmt.Sprintf(`F64 = %d`, -maxSafeFloat64Int), false},
}

for _, tt := range tests {
t.Run("", func(t *testing.T) {
var tab struct {
F32 float32
F64 float64
}
_, err := Decode(tt.value, &tab)

if tt.overflow && err == nil {
t.Fatal("expected error, but err is nil")
}
if (tt.overflow && !errorContains(err, "out of range")) || (!tt.overflow && err != nil) {
t.Fatalf("unexpected error:\n%v", err)
}
})
}
}

func TestDecodeSizedInts(t *testing.T) {
type table struct {
U8 uint8
Expand Down