Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ordering of message queue messages #723

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
52 changes: 39 additions & 13 deletions mssql.go
Expand Up @@ -1074,7 +1074,6 @@ type Rowsq struct {
stmt *Stmt
cols []columnStruct
reader *tokenProcessor
nextCols []columnStruct
cancel func()
requestDone bool
inResultSet bool
Expand Down Expand Up @@ -1102,8 +1101,11 @@ func (rc *Rowsq) Close() error {
}
}

// data/sql calls Columns during the app's call to Next
// ProcessSingleResponse queues MsgNext for every columns token.
// data/sql calls Columns during the app's call to Next.
func (rc *Rowsq) Columns() (res []string) {
// r.cols is nil if the first query in a batch is a SELECT or similar query that returns a rowset.
// if will be non-nil for subsequent queries where NextResultSet() has populated it
if rc.cols == nil {
scan:
for {
Expand Down Expand Up @@ -1145,6 +1147,10 @@ func (rc *Rowsq) Next(dest []driver.Value) error {
if tok == nil {
return io.EOF
} else {
switch tokdata := tok.(type) {
case doneInProcStruct:
tok = (doneStruct)(tokdata)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a more idiomatic way to combine the handling of doneInProcStruct and doneStruct?

}
switch tokdata := tok.(type) {
case []interface{}:
for i := range dest {
Expand Down Expand Up @@ -1172,9 +1178,11 @@ func (rc *Rowsq) Next(dest []driver.Value) error {
if rc.reader.outs.returnStatus != nil {
*rc.reader.outs.returnStatus = tokdata
}
case ServerError:
rc.requestDone = true
return tokdata
}
}

} else {
return rc.stmt.c.checkBadConn(rc.reader.ctx, err, false)
}
Expand All @@ -1187,15 +1195,14 @@ func (rc *Rowsq) HasNextResultSet() bool {
return !rc.requestDone
}

// Scans to the next set of columns in the stream
// Scans to the end of the current statement being processed
// Note that the caller may not have read all the rows in the prior set
func (rc *Rowsq) NextResultSet() error {
if rc.requestDone {
return io.EOF
}
scan:
for {
// we should have a columns token in the channel if we aren't at the end
tok, err := rc.reader.nextToken()
if rc.reader.sess.logFlags&logDebug != 0 {
rc.reader.sess.logger.Log(rc.reader.ctx, msdsn.LogDebug, fmt.Sprintf("NextResultSet() token type:%v", reflect.TypeOf(tok)))
Expand All @@ -1208,23 +1215,42 @@ scan:
return io.EOF
}
switch tokdata := tok.(type) {
case doneInProcStruct:
tok = (doneStruct)(tokdata)
}
// ProcessSingleResponse queues a MsgNextResult for every "done" and "server error" token
// The only tokens to consume after a "done" should be "done", "server error", or "columns"
switch tokdata := tok.(type) {
case []columnStruct:
rc.nextCols = tokdata
rc.cols = tokdata
rc.inResultSet = true
break scan
case doneStruct:
if tokdata.Status&doneMore == 0 {
rc.nextCols = nil
rc.requestDone = true
break scan
}
if tokdata.isError() {
e := rc.stmt.c.checkBadConn(rc.reader.ctx, tokdata.getError(), false)
switch e.(type) {
case Error:
// Ignore non-fatal server errors. Fatal errors are of type ServerError
default:
return e
}
}
rc.inResultSet = false
rc.cols = nil
break scan
case ReturnStatus:
if rc.reader.outs.returnStatus != nil {
*rc.reader.outs.returnStatus = tokdata
}
case ServerError:
rc.requestDone = true
return tokdata
}
}
rc.cols = rc.nextCols
rc.nextCols = nil
if rc.cols == nil {
return io.EOF
}

return nil
}

Expand Down
143 changes: 137 additions & 6 deletions queries_go19_test.go
@@ -1,3 +1,4 @@
//go:build go1.9
// +build go1.9

package mssql
Expand Down Expand Up @@ -1126,17 +1127,19 @@ func TestMessageQueue(t *testing.T) {

msgs := []interface{}{
sqlexp.MsgNotice{Message: "msg1"},
sqlexp.MsgNextResultSet{},
sqlexp.MsgNext{},
sqlexp.MsgRowsAffected{Count: 1},
sqlexp.MsgNextResultSet{},
sqlexp.MsgNotice{Message: "msg2"},
sqlexp.MsgNextResultSet{},
sqlexp.MsgNextResultSet{},
}
i := 0
rsCount := 0
for active {
msg := retmsg.Message(ctx)
if i >= len(msgs) {
t.Fatalf("Got extra message:%+v", msg)
t.Fatalf("Got extra message:%+v", reflect.TypeOf(msg))
}
t.Log(reflect.TypeOf(msg))
if reflect.TypeOf(msgs[i]) != reflect.TypeOf(msg) {
Expand All @@ -1147,10 +1150,6 @@ func TestMessageQueue(t *testing.T) {
t.Log(m.Message)
case sqlexp.MsgNextResultSet:
active = rows.NextResultSet()
if active {
t.Fatal("NextResultSet returned true")
}
rsCount++
case sqlexp.MsgNext:
if !rows.Next() {
t.Fatal("rows.Next() returned false")
Expand Down Expand Up @@ -1368,3 +1367,135 @@ func TestCancelWithNoResults(t *testing.T) {
t.Fatalf("Unexpected error: %v", r.Err())
}
}

const DropSprocWithCursor = `IF EXISTS (SELECT * FROM sys.objects WHERE object_id = OBJECT_ID(N'[dbo].[TestSqlCmd]') AND type in (N'P', N'PC'))
DROP PROCEDURE [dbo].[TestSqlCmd]
`

// This query generates half a dozen tokenDoneInProc tokens which fill the channel if the app isn't scanning Rowsq
const CreateSprocWithCursor = `
CREATE PROCEDURE [dbo].[TestSqlCmd]
AS
BEGIN
DECLARE @tmp int;
DECLARE Server_Cursor CURSOR FOR
SELECT 1 UNION SELECT 2
OPEN Server_Cursor;
FETCH NEXT FROM Server_Cursor INTO @tmp;
WHILE @@FETCH_STATUS = 0
BEGIN
PRINT @tmp
FETCH NEXT FROM Server_Cursor INTO @tmp;
END;
CLOSE Server_Cursor;
DEALLOCATE Server_Cursor;
END
`

func TestSprocWithCursorNoResult(t *testing.T) {
conn, logger := open(t)
defer conn.Close()
defer logger.StopLogging()

_, e := conn.Exec(DropSprocWithCursor)
if e != nil {
t.Fatalf("Unable to drop test sproc: %v", e)
}
_, e = conn.Exec(CreateSprocWithCursor)
if e != nil {
t.Fatalf("Unable to create test sproc: %v", e)
}
defer conn.Exec(DropSprocWithCursor)
latency, _ := getLatency(t)
ctx, cancel := context.WithTimeout(context.Background(), latency+500*time.Millisecond)
defer cancel()
retmsg := &sqlexp.ReturnMessage{}
// Use a sproc instead of the cursor loop directly to cover the different code path in token.go
r, err := conn.QueryContext(ctx, `exec [dbo].[TestSqlCmd]`, retmsg)
if err != nil {
t.Fatal(err.Error())
}
defer r.Close()
active := true
rsCount := 0
msgCount := 0
for active {
msg := retmsg.Message(ctx)
t.Logf("Got a message: %v", reflect.TypeOf(msg))
switch m := msg.(type) {
case sqlexp.MsgNext:
t.Fatalf("Got a MsgNext from a query with no rows")
case sqlexp.MsgError:
t.Fatalf("Got an error: %s", m.Error.Error())
case sqlexp.MsgNotice:
msgCount++
case sqlexp.MsgNextResultSet:
if active = r.NextResultSet(); active {
rsCount++
}
}
}
if r.Err() != nil {
t.Fatalf("Got an error: %v", r.Err())
}
if rsCount != 13 {
t.Fatalf("Unexpected record set count: %v", rsCount)
}
if msgCount != 2 {
t.Fatalf("Unexpected message count: %v", msgCount)
}
}

func TestErrorAsLastResult(t *testing.T) {
conn, logger := open(t)
defer conn.Close()
defer logger.StopLogging()
latency, _ := getLatency(t)
ctx, cancel := context.WithTimeout(context.Background(), latency+5000*time.Millisecond)
defer cancel()
retmsg := &sqlexp.ReturnMessage{}
// Use a sproc instead of the cursor loop directly to cover the different code path in token.go
r, err := conn.QueryContext(ctx,
`
Print N'message'
select 1
raiserror(N'Error!', 16, 1)`,
retmsg)
if err != nil {
t.Fatal(err.Error())
}
defer r.Close()
active := true
d := 0
err = nil
for active {
msg := retmsg.Message(ctx)
t.Logf("Got a message: %s", reflect.TypeOf(msg))
switch m := msg.(type) {
case sqlexp.MsgNext:
if !r.Next() {
t.Fatalf("Next returned false")
}
r.Scan(&d)
if r.Next() {
t.Fatal("Second Next returned true")
}
case sqlexp.MsgError:
err = m.Error
case sqlexp.MsgNextResultSet:
active = r.NextResultSet()
}
}
if err == nil {
t.Fatal("Should have gotten an error message")
} else {
switch e := err.(type) {
case Error:
if e.Message != "Error!" || e.Class != 16 {
t.Fatalf("Got the wrong mssql error %v", e)
}
default:
t.Fatalf("Got an unexpected error %v", e)
}
}
}
21 changes: 13 additions & 8 deletions token.go
Expand Up @@ -645,7 +645,6 @@ func parseReturnValue(r *tdsBuffer) (nv namedValue) {
}

func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenStruct, outs outputs) {
firstResult := true
defer func() {
if err := recover(); err != nil {
if sess.logFlags&logErrors != 0 {
Expand Down Expand Up @@ -704,11 +703,16 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgRowsAffected{Count: int64(done.RowCount)})
}
}
if outs.msgq != nil {
// For now we ignore ctx->Done errors that ReturnMessageEnqueue might return
// It's not clear how to handle them correctly here, and data/sql seems
// to set Rows.Err correctly when ctx expires already
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{})
}
if done.Status&doneMore == 0 {
// Rows marks the request as done when seeing this done token. We queue another result set message
// so the app calls NextResultSet again which will return false.
if outs.msgq != nil {
// For now we ignore ctx->Done errors that ReturnMessageEnqueue might return
// It's not clear how to handle them correctly here, and data/sql seems
// to set Rows.Err correctly when ctx expires already
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{})
}
return
Expand Down Expand Up @@ -738,7 +742,12 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgRowsAffected{Count: int64(done.RowCount)})
}
}
if outs.msgq != nil {
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{})
}
if done.Status&doneMore == 0 {
// Rows marks the request as done when seeing this done token. We queue another result set message
// so the app calls NextResultSet again which will return false.
if outs.msgq != nil {
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{})
}
Expand All @@ -749,12 +758,8 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
ch <- columns

if outs.msgq != nil {
if !firstResult {
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{})
}
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNext{})
}
firstResult = false

case tokenRow:
row := make([]interface{}, len(columns))
Expand Down