Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor out all mssql defined data types into a mssqltypes subpackage #518

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 10 additions & 10 deletions bulkcopy.go
Expand Up @@ -10,7 +10,7 @@ import (
"strings"
"time"

"github.com/denisenkom/go-mssqldb/internal/decimal"
"github.com/denisenkom/go-mssqldb/internal/mssqltypes"
)

type Bulk struct {
Expand Down Expand Up @@ -490,24 +490,24 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
case typeDecimal, typeDecimalN, typeNumeric, typeNumericN:
prec := col.ti.Prec
scale := col.ti.Scale
var dec decimal.Decimal
var dec mssqltypes.Decimal
switch v := val.(type) {
case int:
dec = decimal.Int64ToDecimalScale(int64(v), 0)
dec = mssqltypes.Int64ToDecimalScale(int64(v), 0)
case int8:
dec = decimal.Int64ToDecimalScale(int64(v), 0)
dec = mssqltypes.Int64ToDecimalScale(int64(v), 0)
case int16:
dec = decimal.Int64ToDecimalScale(int64(v), 0)
dec = mssqltypes.Int64ToDecimalScale(int64(v), 0)
case int32:
dec = decimal.Int64ToDecimalScale(int64(v), 0)
dec = mssqltypes.Int64ToDecimalScale(int64(v), 0)
case int64:
dec = decimal.Int64ToDecimalScale(int64(v), 0)
dec = mssqltypes.Int64ToDecimalScale(int64(v), 0)
case float32:
dec, err = decimal.Float64ToDecimalScale(float64(v), scale)
dec, err = mssqltypes.Float64ToDecimalScale(float64(v), scale)
case float64:
dec, err = decimal.Float64ToDecimalScale(float64(v), scale)
dec, err = mssqltypes.Float64ToDecimalScale(float64(v), scale)
case string:
dec, err = decimal.StringToDecimalScale(v, scale)
dec, err = mssqltypes.StringToDecimalScale(v, scale)
default:
return res, fmt.Errorf("unknown value for decimal: %T %#v", v, v)
}
Expand Down
10 changes: 5 additions & 5 deletions datetimeoffset_example_test.go
Expand Up @@ -9,8 +9,8 @@ import (
"log"
"time"

"github.com/denisenkom/go-mssqldb/internal/mssqltypes"
"github.com/golang-sql/civil"
"github.com/denisenkom/go-mssqldb"
)

// This example shows how to insert and retrieve date and time types data
Expand Down Expand Up @@ -56,9 +56,9 @@ func insertDateTime(db *sql.DB) {
var timeCol civil.Time = civil.TimeOf(tin)
var dateCol civil.Date = civil.DateOf(tin)
var smalldatetimeCol string = "2006-01-02 22:04:00"
var datetimeCol mssql.DateTime1 = mssql.DateTime1(tin)
var datetimeCol mssqltypes.DateTime1 = mssqltypes.DateTime1(tin)
var datetime2Col civil.DateTime = civil.DateTimeOf(tin)
var datetimeoffsetCol mssql.DateTimeOffset = mssql.DateTimeOffset(tin)
var datetimeoffsetCol mssqltypes.DateTimeOffset = mssqltypes.DateTimeOffset(tin)
_, err = stmt.Exec(timeCol, dateCol, smalldatetimeCol, datetimeCol, datetime2Col, datetimeoffsetCol)
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -103,8 +103,8 @@ func retrieveDateTimeOutParam(db *sql.DB) {
log.Fatal(err)
}
var (
timeOutParam, datetime2OutParam, datetimeoffsetOutParam mssql.DateTimeOffset
dateOutParam, datetimeOutParam mssql.DateTime1
timeOutParam, datetime2OutParam, datetimeoffsetOutParam mssqltypes.DateTimeOffset
dateOutParam, datetimeOutParam mssqltypes.DateTime1
smalldatetimeOutParam string
)
_, err = db.Exec("OutDatetimeProc",
Expand Down
9 changes: 9 additions & 0 deletions error.go
Expand Up @@ -19,6 +19,7 @@ type Error struct {
LineNo int32
}

// Error returns the SQL Server error message.
func (e Error) Error() string {
return "mssql: " + e.Message
}
Expand All @@ -28,34 +29,42 @@ func (e Error) SQLErrorNumber() int32 {
return e.Number
}

// SQLErrorState returns the SQL Server error state.
func (e Error) SQLErrorState() uint8 {
return e.State
}

// SQLErrorClass returns the SQL Server error class.
func (e Error) SQLErrorClass() uint8 {
return e.Class
}

// SQLErrorMessage returns the SQL Server error message.
func (e Error) SQLErrorMessage() string {
return e.Message
}

// SQLErrorServerName returns the SQL Server name.
func (e Error) SQLErrorServerName() string {
return e.ServerName
}

// SQLErrorProcName returns the procedure name.
func (e Error) SQLErrorProcName() string {
return e.ProcName
}

// SQLErrorLineNo returns the error line number.
func (e Error) SQLErrorLineNo() int32 {
return e.LineNo
}

// StreamError represents TDS stream error.
type StreamError struct {
Message string
}

// Error returns the TDS stream error message.
func (e StreamError) Error() string {
return e.Message
}
Expand Down
@@ -1,4 +1,4 @@
package decimal
package mssqltypes

import (
"encoding/binary"
Expand Down
@@ -1,4 +1,4 @@
package decimal
package mssqltypes

import (
"math"
Expand Down
20 changes: 20 additions & 0 deletions internal/mssqltypes/mssqltypes.go
@@ -0,0 +1,20 @@
// +build go1.9

package mssqltypes

import "time"

// VarChar parameter types.
type VarChar string

// NVarCharMax encodes parameters to NVarChar(max) SQL type.
type NVarCharMax string

// VarCharMax encodes parameter to VarChar(max) SQL type.
type VarCharMax string

// DateTime1 encodes parameters to original DateTime SQL types.
type DateTime1 time.Time

// DateTimeOffset encodes parameters to DateTimeOffset, preserving the UTC offset.
type DateTimeOffset time.Time
@@ -1,4 +1,4 @@
package mssql
package mssqltypes

import (
"database/sql/driver"
Expand All @@ -7,8 +7,10 @@ import (
"fmt"
)

// UniqueIdentifier encodes parameters to Uniqueidentifier SQL type.
type UniqueIdentifier [16]byte

// Scan converts v to UniqueIdentifier
func (u *UniqueIdentifier) Scan(v interface{}) error {
reverse := func(b []byte) {
for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 {
Expand Down Expand Up @@ -52,6 +54,7 @@ func (u *UniqueIdentifier) Scan(v interface{}) error {
}
}

// Value converts UniqueIdentifier to bytes
func (u UniqueIdentifier) Value() (driver.Value, error) {
reverse := func(b []byte) {
for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 {
Expand All @@ -69,6 +72,7 @@ func (u UniqueIdentifier) Value() (driver.Value, error) {
return raw, nil
}

// String converts UniqueIdentifier to string
func (u UniqueIdentifier) String() string {
return fmt.Sprintf("%X-%X-%X-%X-%X", u[0:4], u[4:6], u[6:8], u[8:10], u[10:])
}
Expand Down
@@ -1,4 +1,4 @@
package mssql
package mssqltypes

import (
"bytes"
Expand Down
33 changes: 11 additions & 22 deletions mssql_go19.go
Expand Up @@ -11,6 +11,7 @@ import (
"time"

// "github.com/cockroachdb/apd"
"github.com/denisenkom/go-mssqldb/internal/mssqltypes"
"github.com/golang-sql/civil"
)

Expand All @@ -26,29 +27,17 @@ type MssqlStmt = Stmt // Deprecated: users should transition to th

var _ driver.NamedValueChecker = &Conn{}

// VarChar parameter types.
type VarChar string

type NVarCharMax string
type VarCharMax string

// DateTime1 encodes parameters to original DateTime SQL types.
type DateTime1 time.Time

// DateTimeOffset encodes parameters to DateTimeOffset, preserving the UTC offset.
type DateTimeOffset time.Time

func convertInputParameter(val interface{}) (interface{}, error) {
switch v := val.(type) {
case VarChar:
case mssqltypes.VarChar:
return val, nil
case NVarCharMax:
case mssqltypes.NVarCharMax:
return val, nil
case VarCharMax:
case mssqltypes.VarCharMax:
return val, nil
case DateTime1:
case mssqltypes.DateTime1:
return val, nil
case DateTimeOffset:
case mssqltypes.DateTimeOffset:
return val, nil
case civil.Date:
return val, nil
Expand Down Expand Up @@ -123,24 +112,24 @@ func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error {

func (s *Stmt) makeParamExtra(val driver.Value) (res param, err error) {
switch val := val.(type) {
case VarChar:
case mssqltypes.VarChar:
res.ti.TypeId = typeBigVarChar
res.buffer = []byte(val)
res.ti.Size = len(res.buffer)
case VarCharMax:
case mssqltypes.VarCharMax:
res.ti.TypeId = typeBigVarChar
res.buffer = []byte(val)
res.ti.Size = 0 // currently zero forces varchar(max)
case NVarCharMax:
case mssqltypes.NVarCharMax:
res.ti.TypeId = typeNVarChar
res.buffer = str2ucs2(string(val))
res.ti.Size = 0 // currently zero forces nvarchar(max)
case DateTime1:
case mssqltypes.DateTime1:
t := time.Time(val)
res.ti.TypeId = typeDateTimeN
res.buffer = encodeDateTime(t)
res.ti.Size = len(res.buffer)
case DateTimeOffset:
case mssqltypes.DateTimeOffset:
res.ti.TypeId = typeDateTimeOffsetN
res.ti.Scale = 7
res.buffer = encodeDateTimeOffset(time.Time(val), int(res.ti.Scale))
Expand Down
15 changes: 8 additions & 7 deletions queries_go110_test.go
Expand Up @@ -9,6 +9,7 @@ import (
"testing"
"time"

"github.com/denisenkom/go-mssqldb/internal/mssqltypes"
"github.com/golang-sql/civil"
)

Expand Down Expand Up @@ -80,14 +81,14 @@ select
;
`,
sql.Named("nv", "base type nvarchar"),
sql.Named("v", VarChar("base type varchar")),
sql.Named("nvcm", NVarCharMax(strings.Repeat("x", 5000))),
sql.Named("vcm", VarCharMax(strings.Repeat("x", 5000))),
sql.Named("dt1", DateTime1(tin)),
sql.Named("v", mssqltypes.VarChar("base type varchar")),
sql.Named("nvcm", mssqltypes.NVarCharMax(strings.Repeat("x", 5000))),
sql.Named("vcm", mssqltypes.VarCharMax(strings.Repeat("x", 5000))),
sql.Named("dt1", mssqltypes.DateTime1(tin)),
sql.Named("dt2", civil.DateTimeOf(tin)),
sql.Named("d", civil.DateOf(tin)),
sql.Named("tm", civil.TimeOf(tin)),
sql.Named("dto", DateTimeOffset(tin)),
sql.Named("dto", mssqltypes.DateTimeOffset(tin)),
)
err = row.Scan(&nv, &v, &nvcm, &vcm, &dt1, &dt2, &d, &tm, &dto)
if err != nil {
Expand Down Expand Up @@ -153,11 +154,11 @@ select
sql.Named("nv", sin),
sql.Named("v", sin),
sql.Named("tgo", tin),
sql.Named("dt1", DateTime1(tin)),
sql.Named("dt1", mssqltypes.DateTime1(tin)),
sql.Named("dt2", civil.DateTimeOf(tin)),
sql.Named("d", civil.DateOf(tin)),
sql.Named("tm", civil.TimeOf(tin)),
sql.Named("dto", DateTimeOffset(tin)),
sql.Named("dto", mssqltypes.DateTimeOffset(tin)),
).Scan(&nv, &v, &tgo, &dt1, &dt2, &d, &tm, &dto)
if err != nil {
t.Fatal(err)
Expand Down
20 changes: 11 additions & 9 deletions queries_go19_test.go
Expand Up @@ -10,6 +10,8 @@ import (
"regexp"
"testing"
"time"

"github.com/denisenkom/go-mssqldb/internal/mssqltypes"
)

func TestOutputParam(t *testing.T) {
Expand Down Expand Up @@ -174,8 +176,8 @@ END;
if err != nil {
t.Fatal(err)
}
var datetime_param DateTime1
datetime_param = DateTime1(tin)
var datetime_param mssqltypes.DateTime1
datetime_param = mssqltypes.DateTime1(tin)
_, err = db.ExecContext(ctx, sqltextrun,
sql.Named("datetime", sql.Out{Dest: &datetime_param}),
)
Expand Down Expand Up @@ -346,7 +348,7 @@ END;
t.Run("original test", func(t *testing.T) {
var bout int64 = 3
var cout string
var vout VarChar
var vout mssqltypes.VarChar
_, err = db.ExecContext(ctx, sqltextrun,
sql.Named("aid", 5),
sql.Named("bid", sql.Out{Dest: &bout}),
Expand Down Expand Up @@ -964,12 +966,12 @@ func TestDateTimeParam19(t *testing.T) {
var emptydate time.Time
mindate1 := time.Date(1753, 1, 1, 0, 0, 0, 0, time.UTC)
maxdate1 := time.Date(9999, 12, 31, 23, 59, 59, 997000000, time.UTC)
testdates1 := []DateTime1{
DateTime1(mindate1),
DateTime1(maxdate1),
DateTime1(time.Date(1752, 12, 31, 23, 59, 59, 997000000, time.UTC)), // just a little below minimum date
DateTime1(time.Date(10000, 1, 1, 0, 0, 0, 0, time.UTC)), // just a little over maximum date
DateTime1(emptydate),
testdates1 := []mssqltypes.DateTime1{
mssqltypes.DateTime1(mindate1),
mssqltypes.DateTime1(maxdate1),
mssqltypes.DateTime1(time.Date(1752, 12, 31, 23, 59, 59, 997000000, time.UTC)), // just a little below minimum date
mssqltypes.DateTime1(time.Date(10000, 1, 1, 0, 0, 0, 0, time.UTC)), // just a little over maximum date
mssqltypes.DateTime1(emptydate),
}

for _, test := range testdates1 {
Expand Down
6 changes: 4 additions & 2 deletions queries_test.go
Expand Up @@ -14,6 +14,8 @@ import (
"sync"
"testing"
"time"

"github.com/denisenkom/go-mssqldb/internal/mssqltypes"
)

func driverWithProcess(t *testing.T) *Driver {
Expand Down Expand Up @@ -836,7 +838,7 @@ func TestUniqueIdentifierParam(t *testing.T) {
uuid interface{}
}

expected := UniqueIdentifier{0x01, 0x23, 0x45, 0x67,
expected := mssqltypes.UniqueIdentifier{0x01, 0x23, 0x45, 0x67,
0x89, 0xAB,
0xCD, 0xEF,
0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF,
Expand All @@ -856,7 +858,7 @@ func TestUniqueIdentifierParam(t *testing.T) {

for _, test := range values {
t.Run(test.name, func(t *testing.T) {
var uuid2 UniqueIdentifier
var uuid2 mssqltypes.UniqueIdentifier
err := conn.QueryRow("select @p1", test.uuid).Scan(&uuid2)
if err != nil {
t.Fatal("select / scan failed", err.Error())
Expand Down