Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for 'money' type in bulk import #430

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
39 changes: 35 additions & 4 deletions bulkcopy.go
Expand Up @@ -323,7 +323,15 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)

switch col.ti.TypeId {

case typeInt1, typeInt2, typeInt4, typeInt8, typeIntN:
case typeInt1, typeInt2, typeInt4, typeInt8, typeIntN, typeMoney, typeMoneyN:
// Note: typeMoney is really int64 with a hard-coded fixed
// point convention (123456 is treated as 12.3456). In bulk
// insert it is treated as int64, and here we expect the
// caller to pass it as the underlying int64. This may be a
// bit inconsistent vs. the []byte that comes back from the
// driver on SELECT for money, but at least this solution
// allows for the possibility of doing a bulk insert.

var intvalue int64

switch val := val.(type) {
Expand All @@ -334,12 +342,18 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
case int64:
intvalue = val
default:
err = fmt.Errorf("mssql: invalid type for int column")
if col.ti.TypeId == typeMoney || col.ti.TypeId == typeMoneyN {
err = fmt.Errorf("mssql: please pass money values as int64 for bulk copy (int64 of 12345 turns into money '1.2345')")
} else {
err = fmt.Errorf("mssql: invalid type for int column")
}
return
}

res.buffer = make([]byte, res.ti.Size)
if col.ti.Size == 1 {
if col.ti.TypeId == typeMoney || col.ti.TypeId == typeMoneyN {
encodeMoney(res.buffer, intvalue)
} else if col.ti.Size == 1 {
res.buffer[0] = byte(intvalue)
} else if col.ti.Size == 2 {
binary.LittleEndian.PutUint16(res.buffer, uint16(intvalue))
Expand Down Expand Up @@ -453,7 +467,6 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
err = fmt.Errorf("mssql: invalid type for datetime column: %s", val)
}

// case typeMoney, typeMoney4, typeMoneyN:
case typeDecimal, typeDecimalN, typeNumeric, typeNumericN:
var value float64
switch v := val.(type) {
Expand Down Expand Up @@ -547,6 +560,24 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)

}

// encodeMoney turns a 64-bit integer into the TDS wire format for the
// 'money' type in mssql. The byte ordering was deduced from
// decodeMoney in types.go; could not find it explicitly in the TDS
// documentation. The format has been tested on the wire against real
// SQL Server.
func encodeMoney(out []byte, value int64) {
var buf [8]byte
binary.LittleEndian.PutUint64(buf[:], uint64(value))
out[4] = buf[0]
out[5] = buf[1]
out[6] = buf[2]
out[7] = buf[3]
out[0] = buf[4]
out[1] = buf[5]
out[2] = buf[6]
out[3] = buf[7]
}

func (b *Bulk) dlogf(format string, v ...interface{}) {
if b.Debug {
b.cn.sess.log.Printf(format, v...)
Expand Down
56 changes: 50 additions & 6 deletions bulkcopy_test.go
Expand Up @@ -23,6 +23,11 @@ func TestBulkcopy(t *testing.T) {
val interface{}
}

type differentExpected struct {
input interface{}
expected interface{}
}

tableName := "#table_test"
geom, _ := hex.DecodeString("E6100000010C00000000000034400000000000004440")
bin, _ := hex.DecodeString("ba8b7782168d4033a299333aec17bd33")
Expand Down Expand Up @@ -58,7 +63,6 @@ func TestBulkcopy(t *testing.T) {
{"test_geom", geom},
{"test_uniqueidentifier", []byte{0x6F, 0x96, 0x19, 0xFF, 0x8B, 0x86, 0xD0, 0x11, 0xB4, 0x2D, 0x00, 0xC0, 0x4F, 0xC9, 0x64, 0xFF}},
// {"test_smallmoney", 1234.56},
// {"test_money", 1234.56},
{"test_decimal_18_0", 1234.0001},
{"test_decimal_9_2", 1234.560001},
{"test_decimal_20_0", 1234.0001},
Expand All @@ -68,6 +72,27 @@ func TestBulkcopy(t *testing.T) {
{"test_varbinary_max", bin},
{"test_binary", []byte("1")},
{"test_binary_16", bin},

// money must be input as int64 to bulk insert, but scans back as a string on SELECT, so use `differentExpected` to provide
// different input and expected output

// First test: We do some byte shuffling for the money type, so make sure every byte is unique in the test.
{"test_money_1", differentExpected{
int64(-(0x01<<56 | 0x02<<48 | 0x03<<40 | 0x04<<32 | 0x05<<24 | 0x06<<16 | 0x07<<8 | 0x08)), // evaluates to 72623859790382856
[]byte("-7262385979038.2856")}},
// maximum positive, minimum negative, and zero values
{"test_money_2", differentExpected{math.MaxInt64, []byte("922337203685477.5807")}},
{"test_money_3", differentExpected{math.MinInt64, []byte("-922337203685477.5808")}},
{"test_money_4", differentExpected{0, []byte("0.0000")}},

{"test_money_n_1", differentExpected{
int64(-(0x01<<56 | 0x02<<48 | 0x03<<40 | 0x04<<32 | 0x05<<24 | 0x06<<16 | 0x07<<8 | 0x08)), // evaluates to 72623859790382856
[]byte("-7262385979038.2856")}},
// maximum positive, minimum negative, and zero values
{"test_money_n_2", differentExpected{math.MaxInt64, []byte("922337203685477.5807")}},
{"test_money_n_3", differentExpected{math.MinInt64, []byte("-922337203685477.5808")}},
{"test_money_n_4", differentExpected{0, []byte("0.0000")}},
{"test_money_n_5", nil},
}

columns := make([]string, len(testValues))
Expand All @@ -77,7 +102,12 @@ func TestBulkcopy(t *testing.T) {

values := make([]interface{}, len(testValues))
for i, val := range testValues {
values[i] = val.val
switch t := val.val.(type) {
case differentExpected:
values[i] = t.input
default:
values[i] = val.val
}
}

pool := open(t)
Expand Down Expand Up @@ -149,8 +179,15 @@ func TestBulkcopy(t *testing.T) {
t.Fatal(err)
}
for i, c := range testValues {
if !compareValue(container[i], c.val) {
t.Errorf("columns %s : expected: %v, got: %v\n", c.colname, c.val, container[i])
var expected interface{}
switch t := c.val.(type) {
case differentExpected:
expected = t.expected
default:
expected = c.val
}
if !compareValue(container[i], expected) {
t.Errorf("columns %s : expected: %v, got: %v\n", c.colname, string(expected.([]byte)), string(container[i].([]byte)))
}
}
}
Expand Down Expand Up @@ -203,8 +240,6 @@ func setupTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName str
[test_datetime2_3] [datetime2](3) NULL,
[test_datetime2_7] [datetime2](7) NULL,
[test_date] [date] NULL,
[test_smallmoney] [smallmoney] NULL,
[test_money] [money] NULL,
[test_tinyint] [tinyint] NULL,
[test_smallint] [smallint] NOT NULL,
[test_smallintn] [smallint] NULL,
Expand All @@ -224,6 +259,15 @@ func setupTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName str
[test_varbinary_max] VARBINARY(max) NOT NULL,
[test_binary] BINARY NOT NULL,
[test_binary_16] BINARY(16) NOT NULL,
[test_money_1] MONEY NOT NULL,
[test_money_2] MONEY NOT NULL,
[test_money_3] MONEY NOT NULL,
[test_money_4] MONEY NOT NULL,
[test_money_n_1] MONEY NULL,
[test_money_n_2] MONEY NULL,
[test_money_n_3] MONEY NULL,
[test_money_n_4] MONEY NULL,
[test_money_n_5] MONEY NULL
CONSTRAINT [PK_` + tableName + `_id] PRIMARY KEY CLUSTERED
(
[id] ASC
Expand Down