Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor out all mssql defined data types into a mssqltypes subpackage #518

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 6 additions & 4 deletions buf.go
Expand Up @@ -4,6 +4,8 @@ import (
"encoding/binary"
"errors"
"io"

"github.com/denisenkom/go-mssqldb/internal/mssqlerror"
)

type packetType uint8
Expand Down Expand Up @@ -186,15 +188,15 @@ func (r *tdsBuffer) ReadByte() (res byte, err error) {
func (r *tdsBuffer) byte() byte {
b, err := r.ReadByte()
if err != nil {
badStreamPanic(err)
mssqlerror.BadStreamPanic(err)
}
return b
}

func (r *tdsBuffer) ReadFull(buf []byte) {
_, err := io.ReadFull(r, buf[:])
if err != nil {
badStreamPanic(err)
mssqlerror.BadStreamPanic(err)
}
}

Expand Down Expand Up @@ -227,15 +229,15 @@ func (r *tdsBuffer) BVarChar() string {
func readBVarCharOrPanic(r io.Reader) string {
s, err := readBVarChar(r)
if err != nil {
badStreamPanic(err)
mssqlerror.BadStreamPanic(err)
}
return s
}

func readUsVarCharOrPanic(r io.Reader) string {
s, err := readUsVarChar(r)
if err != nil {
badStreamPanic(err)
mssqlerror.BadStreamPanic(err)
}
return s
}
Expand Down
20 changes: 10 additions & 10 deletions bulkcopy.go
Expand Up @@ -10,7 +10,7 @@ import (
"strings"
"time"

"github.com/denisenkom/go-mssqldb/internal/decimal"
"github.com/denisenkom/go-mssqldb/internal/mssqltypes"
)

type Bulk struct {
Expand Down Expand Up @@ -490,24 +490,24 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
case typeDecimal, typeDecimalN, typeNumeric, typeNumericN:
prec := col.ti.Prec
scale := col.ti.Scale
var dec decimal.Decimal
var dec mssqltypes.Decimal
switch v := val.(type) {
case int:
dec = decimal.Int64ToDecimalScale(int64(v), 0)
dec = mssqltypes.Int64ToDecimalScale(int64(v), 0)
case int8:
dec = decimal.Int64ToDecimalScale(int64(v), 0)
dec = mssqltypes.Int64ToDecimalScale(int64(v), 0)
case int16:
dec = decimal.Int64ToDecimalScale(int64(v), 0)
dec = mssqltypes.Int64ToDecimalScale(int64(v), 0)
case int32:
dec = decimal.Int64ToDecimalScale(int64(v), 0)
dec = mssqltypes.Int64ToDecimalScale(int64(v), 0)
case int64:
dec = decimal.Int64ToDecimalScale(int64(v), 0)
dec = mssqltypes.Int64ToDecimalScale(int64(v), 0)
case float32:
dec, err = decimal.Float64ToDecimalScale(float64(v), scale)
dec, err = mssqltypes.Float64ToDecimalScale(float64(v), scale)
case float64:
dec, err = decimal.Float64ToDecimalScale(float64(v), scale)
dec, err = mssqltypes.Float64ToDecimalScale(float64(v), scale)
case string:
dec, err = decimal.StringToDecimalScale(v, scale)
dec, err = mssqltypes.StringToDecimalScale(v, scale)
default:
return res, fmt.Errorf("unknown value for decimal: %T %#v", v, v)
}
Expand Down
10 changes: 5 additions & 5 deletions datetimeoffset_example_test.go
Expand Up @@ -9,8 +9,8 @@ import (
"log"
"time"

"github.com/denisenkom/go-mssqldb/internal/mssqltypes"
"github.com/golang-sql/civil"
"github.com/denisenkom/go-mssqldb"
)

// This example shows how to insert and retrieve date and time types data
Expand Down Expand Up @@ -56,9 +56,9 @@ func insertDateTime(db *sql.DB) {
var timeCol civil.Time = civil.TimeOf(tin)
var dateCol civil.Date = civil.DateOf(tin)
var smalldatetimeCol string = "2006-01-02 22:04:00"
var datetimeCol mssql.DateTime1 = mssql.DateTime1(tin)
var datetimeCol mssqltypes.DateTime1 = mssqltypes.DateTime1(tin)
var datetime2Col civil.DateTime = civil.DateTimeOf(tin)
var datetimeoffsetCol mssql.DateTimeOffset = mssql.DateTimeOffset(tin)
var datetimeoffsetCol mssqltypes.DateTimeOffset = mssqltypes.DateTimeOffset(tin)
_, err = stmt.Exec(timeCol, dateCol, smalldatetimeCol, datetimeCol, datetime2Col, datetimeoffsetCol)
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -103,8 +103,8 @@ func retrieveDateTimeOutParam(db *sql.DB) {
log.Fatal(err)
}
var (
timeOutParam, datetime2OutParam, datetimeoffsetOutParam mssql.DateTimeOffset
dateOutParam, datetimeOutParam mssql.DateTime1
timeOutParam, datetime2OutParam, datetimeoffsetOutParam mssqltypes.DateTimeOffset
dateOutParam, datetimeOutParam mssqltypes.DateTime1
smalldatetimeOutParam string
)
_, err = db.Exec("OutDatetimeProc",
Expand Down
17 changes: 14 additions & 3 deletions error.go → internal/mssqlerror/error.go
@@ -1,4 +1,4 @@
package mssql
package mssqlerror

import (
"fmt"
Expand All @@ -19,6 +19,7 @@ type Error struct {
LineNo int32
}

// Error returns the SQL Server error message.
yukiwongky marked this conversation as resolved.
Show resolved Hide resolved
func (e Error) Error() string {
return "mssql: " + e.Message
}
Expand All @@ -28,34 +29,42 @@ func (e Error) SQLErrorNumber() int32 {
return e.Number
}

// SQLErrorState returns the SQL Server error state.
func (e Error) SQLErrorState() uint8 {
return e.State
}

// SQLErrorClass returns the SQL Server error class.
func (e Error) SQLErrorClass() uint8 {
return e.Class
}

// SQLErrorMessage returns the SQL Server error message.
func (e Error) SQLErrorMessage() string {
return e.Message
}

// SQLErrorServerName returns the SQL Server name.
func (e Error) SQLErrorServerName() string {
return e.ServerName
}

// SQLErrorProcName returns the procedure name.
func (e Error) SQLErrorProcName() string {
return e.ProcName
}

// SQLErrorLineNo returns the error line number.
func (e Error) SQLErrorLineNo() int32 {
return e.LineNo
}

// StreamError represents TDS stream error.
type StreamError struct {
Message string
}

// Error returns the TDS stream error message.
func (e StreamError) Error() string {
return e.Message
}
Expand All @@ -64,10 +73,12 @@ func streamErrorf(format string, v ...interface{}) StreamError {
return StreamError{"Invalid TDS stream: " + fmt.Sprintf(format, v...)}
}

func badStreamPanic(err error) {
// BadStreamPanic calls panic with err.
func BadStreamPanic(err error) {
panic(err)
}

func badStreamPanicf(format string, v ...interface{}) {
// BadStreamPanicf calls panic with a formatted error message as an invalid TDS stream error.
func BadStreamPanicf(format string, v ...interface{}) {
panic(streamErrorf(format, v...))
}
@@ -1,4 +1,4 @@
package decimal
package mssqltypes

import (
"encoding/binary"
Expand Down
@@ -1,4 +1,4 @@
package decimal
package mssqltypes

import (
"math"
Expand Down
20 changes: 20 additions & 0 deletions internal/mssqltypes/mssqltypes.go
@@ -0,0 +1,20 @@
// +build go1.9

package mssqltypes

import "time"

// VarChar parameter types.
type VarChar string

// NVarCharMax encodes parameters to NVarChar(max) SQL type.
type NVarCharMax string

// VarCharMax encodes parameter to VarChar(max) SQL type.
type VarCharMax string

// DateTime1 encodes parameters to original DateTime SQL types.
type DateTime1 time.Time

// DateTimeOffset encodes parameters to DateTimeOffset, preserving the UTC offset.
type DateTimeOffset time.Time
@@ -1,4 +1,4 @@
package mssql
package mssqltypes

import (
"database/sql/driver"
Expand All @@ -7,8 +7,10 @@ import (
"fmt"
)

// UniqueIdentifier encodes parameters to Uniqueidentifier SQL type.
type UniqueIdentifier [16]byte

// Scan converts v to UniqueIdentifier
func (u *UniqueIdentifier) Scan(v interface{}) error {
reverse := func(b []byte) {
for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 {
Expand Down Expand Up @@ -52,6 +54,7 @@ func (u *UniqueIdentifier) Scan(v interface{}) error {
}
}

// Value converts UniqueIdentifier to bytes
func (u UniqueIdentifier) Value() (driver.Value, error) {
reverse := func(b []byte) {
for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 {
Expand All @@ -69,6 +72,7 @@ func (u UniqueIdentifier) Value() (driver.Value, error) {
return raw, nil
}

// String converts UniqueIdentifier to string
func (u UniqueIdentifier) String() string {
return fmt.Sprintf("%X-%X-%X-%X-%X", u[0:4], u[4:6], u[6:8], u[8:10], u[10:])
}
@@ -1,4 +1,4 @@
package mssql
package mssqltypes

import (
"bytes"
Expand Down
3 changes: 2 additions & 1 deletion mssql.go
Expand Up @@ -15,6 +15,7 @@ import (
"time"
"unicode"

"github.com/denisenkom/go-mssqldb/internal/mssqlerror"
"github.com/denisenkom/go-mssqldb/internal/querytext"
)

Expand Down Expand Up @@ -188,7 +189,7 @@ func (c *Conn) checkBadConn(err error) error {
case net.Error:
c.connectionGood = false
return err
case StreamError:
case mssqlerror.StreamError:
c.connectionGood = false
return err
default:
Expand Down
33 changes: 11 additions & 22 deletions mssql_go19.go
Expand Up @@ -11,6 +11,7 @@ import (
"time"

// "github.com/cockroachdb/apd"
"github.com/denisenkom/go-mssqldb/internal/mssqltypes"
"github.com/golang-sql/civil"
)

Expand All @@ -26,29 +27,17 @@ type MssqlStmt = Stmt // Deprecated: users should transition to th

var _ driver.NamedValueChecker = &Conn{}

// VarChar parameter types.
type VarChar string

type NVarCharMax string
type VarCharMax string

// DateTime1 encodes parameters to original DateTime SQL types.
type DateTime1 time.Time

// DateTimeOffset encodes parameters to DateTimeOffset, preserving the UTC offset.
type DateTimeOffset time.Time

func convertInputParameter(val interface{}) (interface{}, error) {
switch v := val.(type) {
case VarChar:
case mssqltypes.VarChar:
return val, nil
case NVarCharMax:
case mssqltypes.NVarCharMax:
return val, nil
case VarCharMax:
case mssqltypes.VarCharMax:
return val, nil
case DateTime1:
case mssqltypes.DateTime1:
return val, nil
case DateTimeOffset:
case mssqltypes.DateTimeOffset:
return val, nil
case civil.Date:
return val, nil
Expand Down Expand Up @@ -123,24 +112,24 @@ func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error {

func (s *Stmt) makeParamExtra(val driver.Value) (res param, err error) {
switch val := val.(type) {
case VarChar:
case mssqltypes.VarChar:
res.ti.TypeId = typeBigVarChar
res.buffer = []byte(val)
res.ti.Size = len(res.buffer)
case VarCharMax:
case mssqltypes.VarCharMax:
res.ti.TypeId = typeBigVarChar
res.buffer = []byte(val)
res.ti.Size = 0 // currently zero forces varchar(max)
case NVarCharMax:
case mssqltypes.NVarCharMax:
res.ti.TypeId = typeNVarChar
res.buffer = str2ucs2(string(val))
res.ti.Size = 0 // currently zero forces nvarchar(max)
case DateTime1:
case mssqltypes.DateTime1:
t := time.Time(val)
res.ti.TypeId = typeDateTimeN
res.buffer = encodeDateTime(t)
res.ti.Size = len(res.buffer)
case DateTimeOffset:
case mssqltypes.DateTimeOffset:
res.ti.TypeId = typeDateTimeOffsetN
res.ti.Scale = 7
res.buffer = encodeDateTimeOffset(time.Time(val), int(res.ti.Scale))
Expand Down
15 changes: 8 additions & 7 deletions queries_go110_test.go
Expand Up @@ -9,6 +9,7 @@ import (
"testing"
"time"

"github.com/denisenkom/go-mssqldb/internal/mssqltypes"
"github.com/golang-sql/civil"
)

Expand Down Expand Up @@ -80,14 +81,14 @@ select
;
`,
sql.Named("nv", "base type nvarchar"),
sql.Named("v", VarChar("base type varchar")),
sql.Named("nvcm", NVarCharMax(strings.Repeat("x", 5000))),
sql.Named("vcm", VarCharMax(strings.Repeat("x", 5000))),
sql.Named("dt1", DateTime1(tin)),
sql.Named("v", mssqltypes.VarChar("base type varchar")),
sql.Named("nvcm", mssqltypes.NVarCharMax(strings.Repeat("x", 5000))),
sql.Named("vcm", mssqltypes.VarCharMax(strings.Repeat("x", 5000))),
sql.Named("dt1", mssqltypes.DateTime1(tin)),
sql.Named("dt2", civil.DateTimeOf(tin)),
sql.Named("d", civil.DateOf(tin)),
sql.Named("tm", civil.TimeOf(tin)),
sql.Named("dto", DateTimeOffset(tin)),
sql.Named("dto", mssqltypes.DateTimeOffset(tin)),
)
err = row.Scan(&nv, &v, &nvcm, &vcm, &dt1, &dt2, &d, &tm, &dto)
if err != nil {
Expand Down Expand Up @@ -153,11 +154,11 @@ select
sql.Named("nv", sin),
sql.Named("v", sin),
sql.Named("tgo", tin),
sql.Named("dt1", DateTime1(tin)),
sql.Named("dt1", mssqltypes.DateTime1(tin)),
sql.Named("dt2", civil.DateTimeOf(tin)),
sql.Named("d", civil.DateOf(tin)),
sql.Named("tm", civil.TimeOf(tin)),
sql.Named("dto", DateTimeOffset(tin)),
sql.Named("dto", mssqltypes.DateTimeOffset(tin)),
).Scan(&nv, &v, &tgo, &dt1, &dt2, &d, &tm, &dto)
if err != nil {
t.Fatal(err)
Expand Down