Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

internal/transport: convert ConnectionError to Unavailable status when writing headers #6891

Merged
merged 3 commits into from Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 6 additions & 1 deletion internal/transport/http2_server.go
Expand Up @@ -960,7 +960,12 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
}
}
if err := t.writeHeaderLocked(s); err != nil {
return status.Convert(err).Err()
switch e := err.(type) {
case ConnectionError:
return status.Error(codes.Unavailable, e.Desc)
default:
return status.Convert(err).Err()
}
}
return nil
}
Expand Down
65 changes: 65 additions & 0 deletions internal/transport/transport_test.go
Expand Up @@ -45,6 +45,7 @@ import (
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/status"
)
Expand Down Expand Up @@ -2136,6 +2137,70 @@ func (s) TestHeadersHTTPStatusGRPCStatus(t *testing.T) {
}
}

func (s) TestWriteHeaderConnectionError(t *testing.T) {
server, client, cancel := setUp(t, 0, notifyCall)
defer cancel()
defer server.stop()

waitWhileTrue(t, func() (bool, error) {
server.mu.Lock()
defer server.mu.Unlock()

if len(server.conns) == 0 {
return true, fmt.Errorf("timed-out while waiting for connection to be created on the server")
}
return false, nil
})

server.mu.Lock()

if len(server.conns) != 1 {
t.Fatalf("Server has %d connections from the client, want 1", len(server.conns))
}

// Get the server transport for the connecton to the client.
var serverTransport *http2Server
for k := range server.conns {
serverTransport = k.(*http2Server)
}
notifyChan := make(chan struct{})
server.h.notify = notifyChan
server.mu.Unlock()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
cstream, err := client.NewStream(ctx, &CallHdr{})
if err != nil {
t.Fatalf("Client failed to create first stream. Err: %v", err)
}

<-notifyChan // Wait for server stream to be established.
var sstream *Stream
// Access stream on the server.
serverTransport.mu.Lock()
for _, v := range serverTransport.activeStreams {
if v.id == cstream.id {
sstream = v
}
}
serverTransport.mu.Unlock()
if sstream == nil {
t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream.id)
}

client.Close(fmt.Errorf("closed manually by test"))

// Wait for server transport to be closed.
<-serverTransport.done

// Write header on a closed server transport.
err = serverTransport.WriteHeader(sstream, metadata.MD{})
st := status.Convert(err)
if st.Code() != codes.Unavailable {
t.Fatalf("WriteHeader() failed with status code %s, want %s", st.Code(), codes.Unavailable)
}
}

func (s) TestPingPong1B(t *testing.T) {
runPingPongTest(t, 1)
}
Expand Down