Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ignore spurious "rowcount" data from the server #735

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
33 changes: 33 additions & 0 deletions .github/workflows/pr-validation.yml
@@ -0,0 +1,33 @@
name: pr-validation

on:
pull_request:
branches:
- main

jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
go: ['1.16','1.17', '1.18']
sqlImage: ['2017-latest','2019-latest']
steps:
- uses: actions/checkout@v2
- name: Setup go
uses: actions/setup-go@v2
with:
go-version: '${{ matrix.go }}'
- name: Run tests against Linux SQL
run: |
go version
go get -d
export SQLCMDPASSWORD=$(date +%s|sha256sum|base64|head -c 32)
export SQLCMDUSER=sa
export SQLUSER=sa
export SQLPASSWORD=$SQLCMDPASSWORD
export DATABASE=test
docker run -m 2GB -e ACCEPT_EULA=1 -d --name sqlserver -p:1433:1433 -e SA_PASSWORD=$SQLCMDPASSWORD mcr.microsoft.com/mssql/server:${{ matrix.sqlImage }}
sleep 10
sqlcmd -Q "CREATE DATABASE test"
go test -race -cpu 4 ./...
52 changes: 39 additions & 13 deletions mssql.go
Expand Up @@ -1074,7 +1074,6 @@ type Rowsq struct {
stmt *Stmt
cols []columnStruct
reader *tokenProcessor
nextCols []columnStruct
cancel func()
requestDone bool
inResultSet bool
Expand Down Expand Up @@ -1102,8 +1101,11 @@ func (rc *Rowsq) Close() error {
}
}

// data/sql calls Columns during the app's call to Next
// ProcessSingleResponse queues MsgNext for every columns token.
// data/sql calls Columns during the app's call to Next.
func (rc *Rowsq) Columns() (res []string) {
// r.cols is nil if the first query in a batch is a SELECT or similar query that returns a rowset.
// if will be non-nil for subsequent queries where NextResultSet() has populated it
if rc.cols == nil {
scan:
for {
Expand Down Expand Up @@ -1145,6 +1147,10 @@ func (rc *Rowsq) Next(dest []driver.Value) error {
if tok == nil {
return io.EOF
} else {
switch tokdata := tok.(type) {
case doneInProcStruct:
tok = (doneStruct)(tokdata)
}
switch tokdata := tok.(type) {
case []interface{}:
for i := range dest {
Expand Down Expand Up @@ -1172,9 +1178,11 @@ func (rc *Rowsq) Next(dest []driver.Value) error {
if rc.reader.outs.returnStatus != nil {
*rc.reader.outs.returnStatus = tokdata
}
case ServerError:
rc.requestDone = true
return tokdata
}
}

} else {
return rc.stmt.c.checkBadConn(rc.reader.ctx, err, false)
}
Expand All @@ -1187,15 +1195,14 @@ func (rc *Rowsq) HasNextResultSet() bool {
return !rc.requestDone
}

// Scans to the next set of columns in the stream
// Scans to the end of the current statement being processed
// Note that the caller may not have read all the rows in the prior set
func (rc *Rowsq) NextResultSet() error {
if rc.requestDone {
return io.EOF
}
scan:
for {
// we should have a columns token in the channel if we aren't at the end
tok, err := rc.reader.nextToken()
if rc.reader.sess.logFlags&logDebug != 0 {
rc.reader.sess.logger.Log(rc.reader.ctx, msdsn.LogDebug, fmt.Sprintf("NextResultSet() token type:%v", reflect.TypeOf(tok)))
Expand All @@ -1208,23 +1215,42 @@ scan:
return io.EOF
}
switch tokdata := tok.(type) {
case doneInProcStruct:
tok = (doneStruct)(tokdata)
}
// ProcessSingleResponse queues a MsgNextResult for every "done" and "server error" token
// The only tokens to consume after a "done" should be "done", "server error", or "columns"
switch tokdata := tok.(type) {
case []columnStruct:
rc.nextCols = tokdata
rc.cols = tokdata
rc.inResultSet = true
break scan
case doneStruct:
if tokdata.Status&doneMore == 0 {
rc.nextCols = nil
rc.requestDone = true
break scan
}
if tokdata.isError() {
e := rc.stmt.c.checkBadConn(rc.reader.ctx, tokdata.getError(), false)
switch e.(type) {
case Error:
// Ignore non-fatal server errors. Fatal errors are of type ServerError
default:
return e
}
}
rc.inResultSet = false
rc.cols = nil
break scan
case ReturnStatus:
if rc.reader.outs.returnStatus != nil {
*rc.reader.outs.returnStatus = tokdata
}
case ServerError:
rc.requestDone = true
return tokdata
}
}
rc.cols = rc.nextCols
rc.nextCols = nil
if rc.cols == nil {
return io.EOF
}

return nil
}

Expand Down
146 changes: 139 additions & 7 deletions queries_go19_test.go
@@ -1,3 +1,4 @@
//go:build go1.9
// +build go1.9

package mssql
Expand Down Expand Up @@ -1126,17 +1127,19 @@ func TestMessageQueue(t *testing.T) {

msgs := []interface{}{
sqlexp.MsgNotice{Message: "msg1"},
sqlexp.MsgNextResultSet{},
sqlexp.MsgNext{},
sqlexp.MsgRowsAffected{Count: 1},
sqlexp.MsgNextResultSet{},
sqlexp.MsgNotice{Message: "msg2"},
sqlexp.MsgNextResultSet{},
sqlexp.MsgNextResultSet{},
}
i := 0
rsCount := 0
for active {
msg := retmsg.Message(ctx)
if i >= len(msgs) {
t.Fatalf("Got extra message:%+v", msg)
t.Fatalf("Got extra message:%+v", reflect.TypeOf(msg))
}
t.Log(reflect.TypeOf(msg))
if reflect.TypeOf(msgs[i]) != reflect.TypeOf(msg) {
Expand All @@ -1147,10 +1150,6 @@ func TestMessageQueue(t *testing.T) {
t.Log(m.Message)
case sqlexp.MsgNextResultSet:
active = rows.NextResultSet()
if active {
t.Fatal("NextResultSet returned true")
}
rsCount++
case sqlexp.MsgNext:
if !rows.Next() {
t.Fatal("rows.Next() returned false")
Expand Down Expand Up @@ -1237,7 +1236,8 @@ select getdate()
PRINT N'This is a message'
select 199
RAISERROR (N'Testing!' , 11, 1)
select 300
declare @d int = 300
select @d
`

func testMixedQuery(conn *sql.DB, b testing.TB) (msgs, errs, results, rowcounts int) {
Expand Down Expand Up @@ -1368,3 +1368,135 @@ func TestCancelWithNoResults(t *testing.T) {
t.Fatalf("Unexpected error: %v", r.Err())
}
}

const DropSprocWithCursor = `IF EXISTS (SELECT * FROM sys.objects WHERE object_id = OBJECT_ID(N'[dbo].[TestSqlCmd]') AND type in (N'P', N'PC'))
DROP PROCEDURE [dbo].[TestSqlCmd]
`

// This query generates half a dozen tokenDoneInProc tokens which fill the channel if the app isn't scanning Rowsq
const CreateSprocWithCursor = `
CREATE PROCEDURE [dbo].[TestSqlCmd]
AS
BEGIN
DECLARE @tmp int;
DECLARE Server_Cursor CURSOR FOR
SELECT 1 UNION SELECT 2
OPEN Server_Cursor;
FETCH NEXT FROM Server_Cursor INTO @tmp;
WHILE @@FETCH_STATUS = 0
BEGIN
PRINT @tmp
FETCH NEXT FROM Server_Cursor INTO @tmp;
END;
CLOSE Server_Cursor;
DEALLOCATE Server_Cursor;
END
`

func TestSprocWithCursorNoResult(t *testing.T) {
conn, logger := open(t)
defer conn.Close()
defer logger.StopLogging()

_, e := conn.Exec(DropSprocWithCursor)
if e != nil {
t.Fatalf("Unable to drop test sproc: %v", e)
}
_, e = conn.Exec(CreateSprocWithCursor)
if e != nil {
t.Fatalf("Unable to create test sproc: %v", e)
}
defer conn.Exec(DropSprocWithCursor)
latency, _ := getLatency(t)
ctx, cancel := context.WithTimeout(context.Background(), latency+500*time.Millisecond)
defer cancel()
retmsg := &sqlexp.ReturnMessage{}
// Use a sproc instead of the cursor loop directly to cover the different code path in token.go
r, err := conn.QueryContext(ctx, `exec [dbo].[TestSqlCmd]`, retmsg)
if err != nil {
t.Fatal(err.Error())
}
defer r.Close()
active := true
rsCount := 0
msgCount := 0
for active {
msg := retmsg.Message(ctx)
t.Logf("Got a message: %v", reflect.TypeOf(msg))
switch m := msg.(type) {
case sqlexp.MsgNext:
t.Fatalf("Got a MsgNext from a query with no rows")
case sqlexp.MsgError:
t.Fatalf("Got an error: %s", m.Error.Error())
case sqlexp.MsgNotice:
msgCount++
case sqlexp.MsgNextResultSet:
if active = r.NextResultSet(); active {
rsCount++
}
}
}
if r.Err() != nil {
t.Fatalf("Got an error: %v", r.Err())
}
if rsCount != 13 {
t.Fatalf("Unexpected record set count: %v", rsCount)
}
if msgCount != 2 {
t.Fatalf("Unexpected message count: %v", msgCount)
}
}

func TestErrorAsLastResult(t *testing.T) {
conn, logger := open(t)
defer conn.Close()
defer logger.StopLogging()
latency, _ := getLatency(t)
ctx, cancel := context.WithTimeout(context.Background(), latency+5000*time.Millisecond)
defer cancel()
retmsg := &sqlexp.ReturnMessage{}
// Use a sproc instead of the cursor loop directly to cover the different code path in token.go
r, err := conn.QueryContext(ctx,
`
Print N'message'
select 1
raiserror(N'Error!', 16, 1)`,
retmsg)
if err != nil {
t.Fatal(err.Error())
}
defer r.Close()
active := true
d := 0
err = nil
for active {
msg := retmsg.Message(ctx)
t.Logf("Got a message: %s", reflect.TypeOf(msg))
switch m := msg.(type) {
case sqlexp.MsgNext:
if !r.Next() {
t.Fatalf("Next returned false")
}
r.Scan(&d)
if r.Next() {
t.Fatal("Second Next returned true")
}
case sqlexp.MsgError:
err = m.Error
case sqlexp.MsgNextResultSet:
active = r.NextResultSet()
}
}
if err == nil {
t.Fatal("Should have gotten an error message")
} else {
switch e := err.(type) {
case Error:
if e.Message != "Error!" || e.Class != 16 {
t.Fatalf("Got the wrong mssql error %v", e)
}
default:
t.Fatalf("Got an unexpected error %v", e)
}
}
}