diff --git a/pom.xml b/pom.xml index 7e2db780..9b4f26ac 100644 --- a/pom.xml +++ b/pom.xml @@ -130,6 +130,11 @@ scram-client ${scram-client.version} + + com.ongres.scram + scram-common + ${scram-client.version} + io.projectreactor reactor-core diff --git a/src/main/java/io/r2dbc/postgresql/SingleHostConnectionFunction.java b/src/main/java/io/r2dbc/postgresql/SingleHostConnectionFunction.java index e3383c18..31431753 100644 --- a/src/main/java/io/r2dbc/postgresql/SingleHostConnectionFunction.java +++ b/src/main/java/io/r2dbc/postgresql/SingleHostConnectionFunction.java @@ -20,6 +20,7 @@ import io.r2dbc.postgresql.authentication.PasswordAuthenticationHandler; import io.r2dbc.postgresql.authentication.SASLAuthenticationHandler; import io.r2dbc.postgresql.client.Client; +import io.r2dbc.postgresql.client.ConnectionContext; import io.r2dbc.postgresql.client.ConnectionSettings; import io.r2dbc.postgresql.client.PostgresStartupParameterProvider; import io.r2dbc.postgresql.client.StartupMessageFlow; @@ -46,7 +47,7 @@ public Mono connect(SocketAddress endpoint, ConnectionSettings settings) return this.upstreamFunction.connect(endpoint, settings) .delayUntil(client -> getCredentials().flatMapMany(credentials -> StartupMessageFlow - .exchange(auth -> getAuthenticationHandler(auth, credentials), client, this.configuration.getDatabase(), credentials.getUsername(), + .exchange(auth -> getAuthenticationHandler(auth, credentials, client.getContext()), client, this.configuration.getDatabase(), credentials.getUsername(), getParameterProvider(this.configuration, settings))) .handle(ExceptionFactory.INSTANCE::handleErrorResponse)); } @@ -55,13 +56,13 @@ private static PostgresStartupParameterProvider getParameterProvider(PostgresqlC return new PostgresStartupParameterProvider(configuration.getApplicationName(), configuration.getTimeZone(), settings); } - protected AuthenticationHandler getAuthenticationHandler(AuthenticationMessage message, UsernameAndPassword usernameAndPassword) { + protected AuthenticationHandler getAuthenticationHandler(AuthenticationMessage message, UsernameAndPassword usernameAndPassword, ConnectionContext context) { if (PasswordAuthenticationHandler.supports(message)) { CharSequence password = Assert.requireNonNull(usernameAndPassword.getPassword(), "Password must not be null"); return new PasswordAuthenticationHandler(password, usernameAndPassword.getUsername()); } else if (SASLAuthenticationHandler.supports(message)) { CharSequence password = Assert.requireNonNull(usernameAndPassword.getPassword(), "Password must not be null"); - return new SASLAuthenticationHandler(password, usernameAndPassword.getUsername()); + return new SASLAuthenticationHandler(password, usernameAndPassword.getUsername(), context); } else { throw new IllegalStateException(String.format("Unable to provide AuthenticationHandler capable of handling %s", message)); } diff --git a/src/main/java/io/r2dbc/postgresql/authentication/SASLAuthenticationHandler.java b/src/main/java/io/r2dbc/postgresql/authentication/SASLAuthenticationHandler.java index c38c6bac..0ca21ac9 100644 --- a/src/main/java/io/r2dbc/postgresql/authentication/SASLAuthenticationHandler.java +++ b/src/main/java/io/r2dbc/postgresql/authentication/SASLAuthenticationHandler.java @@ -3,7 +3,8 @@ import com.ongres.scram.client.ScramClient; import com.ongres.scram.common.StringPreparation; import com.ongres.scram.common.exception.ScramException; - +import com.ongres.scram.common.util.TlsServerEndpoint; +import io.r2dbc.postgresql.client.ConnectionContext; import io.r2dbc.postgresql.message.backend.AuthenticationMessage; import io.r2dbc.postgresql.message.backend.AuthenticationSASL; import io.r2dbc.postgresql.message.backend.AuthenticationSASLContinue; @@ -14,14 +15,26 @@ import io.r2dbc.postgresql.util.Assert; import io.r2dbc.postgresql.util.ByteBufferUtils; import reactor.core.Exceptions; +import reactor.util.Logger; +import reactor.util.Loggers; import reactor.util.annotation.Nullable; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLSession; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; + public class SASLAuthenticationHandler implements AuthenticationHandler { + private static final Logger LOG = Loggers.getLogger(SASLAuthenticationHandler.class); + private final CharSequence password; private final String username; + private final ConnectionContext context; + private ScramClient scramClient; /** @@ -29,11 +42,13 @@ public class SASLAuthenticationHandler implements AuthenticationHandler { * * @param password the password to use for authentication * @param username the username to use for authentication + * @param context the connection context * @throws IllegalArgumentException if {@code password} or {@code user} is {@code null} */ - public SASLAuthenticationHandler(CharSequence password, String username) { + public SASLAuthenticationHandler(CharSequence password, String username, ConnectionContext context) { this.password = Assert.requireNonNull(password, "password must not be null"); this.username = Assert.requireNonNull(username, "username must not be null"); + this.context = Assert.requireNonNull(context, "context must not be null"); } /** @@ -67,14 +82,44 @@ public FrontendMessage handle(AuthenticationMessage message) { } private FrontendMessage handleAuthenticationSASL(AuthenticationSASL message) { - this.scramClient = ScramClient.builder() + + char[] password = new char[this.password.length()]; + for (int i = 0; i < password.length; i++) { + password[i] = this.password.charAt(i); + } + + ScramClient.FinalBuildStage builder = ScramClient.builder() .advertisedMechanisms(message.getAuthenticationMechanisms()) - .username(username) // ignored by the server, use startup message - .password(password.toString().toCharArray()) - .stringPreparation(StringPreparation.POSTGRESQL_PREPARATION) - .build(); + .username(this.username) // ignored by the server, use startup message + .password(password) + .stringPreparation(StringPreparation.POSTGRESQL_PREPARATION); + + SSLSession sslSession = this.context.getSslSession(); - return new SASLInitialResponse(ByteBufferUtils.encode(this.scramClient.clientFirstMessage().toString()), scramClient.getScramMechanism().getName()); + if (sslSession != null && sslSession.isValid()) { + builder.channelBinding(TlsServerEndpoint.TLS_SERVER_END_POINT, extractSslEndpoint(sslSession)); + } + + this.scramClient = builder.build(); + + return new SASLInitialResponse(ByteBufferUtils.encode(this.scramClient.clientFirstMessage().toString()), this.scramClient.getScramMechanism().getName()); + } + + private static byte[] extractSslEndpoint(SSLSession sslSession) { + try { + Certificate[] certificates = sslSession.getPeerCertificates(); + if (certificates != null && certificates.length > 0) { + Certificate peerCert = certificates[0]; // First certificate is the peer's certificate + if (peerCert instanceof X509Certificate) { + X509Certificate cert = (X509Certificate) peerCert; + return TlsServerEndpoint.getChannelBindingData(cert); + + } + } + } catch (CertificateException | SSLException e) { + LOG.debug("Cannot extract X509Certificate from SSL session", e); + } + return new byte[0]; } private FrontendMessage handleAuthenticationSASLContinue(AuthenticationSASLContinue message) { diff --git a/src/main/java/io/r2dbc/postgresql/client/ConnectionContext.java b/src/main/java/io/r2dbc/postgresql/client/ConnectionContext.java index 90796c33..c348d606 100644 --- a/src/main/java/io/r2dbc/postgresql/client/ConnectionContext.java +++ b/src/main/java/io/r2dbc/postgresql/client/ConnectionContext.java @@ -20,7 +20,9 @@ import reactor.util.Loggers; import javax.annotation.Nullable; +import javax.net.ssl.SSLSession; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; /** * Value object capturing diagnostic connection context. Allows for log-message post-processing with {@link #getMessage(String) if the logger category for @@ -50,6 +52,8 @@ public final class ConnectionContext { private final String connectionIdPrefix; + private final Supplier sslSession; + /** * Create a new {@link ConnectionContext} with a unique connection Id. */ @@ -58,13 +62,15 @@ public ConnectionContext() { this.connectionCounter = incrementConnectionCounter(); this.connectionIdPrefix = getConnectionIdPrefix(); this.channelId = null; + this.sslSession = () -> null; } - private ConnectionContext(@Nullable Integer processId, @Nullable String channelId, String connectionCounter) { + private ConnectionContext(@Nullable Integer processId, @Nullable String channelId, String connectionCounter, Supplier sslSession) { this.processId = processId; this.channelId = channelId; this.connectionCounter = connectionCounter; this.connectionIdPrefix = getConnectionIdPrefix(); + this.sslSession = sslSession; } private String incrementConnectionCounter() { @@ -101,6 +107,11 @@ public String getMessage(String original) { return original; } + @Nullable + public SSLSession getSslSession() { + return this.sslSession.get(); + } + /** * Create a new {@link ConnectionContext} by associating the {@code channelId}. * @@ -108,7 +119,17 @@ public String getMessage(String original) { * @return a new {@link ConnectionContext} with all previously set values and the associated {@code channelId}. */ public ConnectionContext withChannelId(String channelId) { - return new ConnectionContext(this.processId, channelId, this.connectionCounter); + return new ConnectionContext(this.processId, channelId, this.connectionCounter, this.sslSession); + } + + /** + * Create a new {@link ConnectionContext} by associating the {@code sslSession}. + * + * @param sslSession the SSL session supplier. + * @return a new {@link ConnectionContext} with all previously set values and the associated {@code sslSession}. + */ + public ConnectionContext withSslSession(Supplier sslSession) { + return new ConnectionContext(this.processId, this.channelId, this.connectionCounter, sslSession); } /** @@ -118,7 +139,7 @@ public ConnectionContext withChannelId(String channelId) { * @return a new {@link ConnectionContext} with all previously set values and the associated {@code processId}. */ public ConnectionContext withProcessId(int processId) { - return new ConnectionContext(processId, this.channelId, this.connectionCounter); + return new ConnectionContext(processId, this.channelId, this.connectionCounter, this.sslSession); } } diff --git a/src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java b/src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java index 1effe894..8ff9cf30 100644 --- a/src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java +++ b/src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java @@ -25,6 +25,7 @@ import io.netty.handler.codec.LengthFieldBasedFrameDecoder; import io.netty.handler.logging.LogLevel; import io.netty.handler.logging.LoggingHandler; +import io.netty.handler.ssl.SslHandler; import io.netty.util.ReferenceCountUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -148,7 +149,23 @@ private ReactorNettyClient(Connection connection, ConnectionSettings settings) { connection.addHandlerLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE - 5, 1, 4, -4, 0)); this.connection = connection; this.byteBufAllocator = connection.outbound().alloc(); - this.context = new ConnectionContext().withChannelId(connection.channel().toString()); + + ConnectionContext connectionContext = new ConnectionContext().withChannelId(connection.channel().toString()); + SslHandler sslHandler = this.connection.channel().pipeline().get(SslHandler.class); + + if (sslHandler == null) { + SSLSessionHandlerAdapter handlerAdapter = this.connection.channel().pipeline().get(SSLSessionHandlerAdapter.class); + if (handlerAdapter != null) { + sslHandler = handlerAdapter.getSslHandler(); + } + } + + if (sslHandler != null) { + SslHandler toUse = sslHandler; + connectionContext = connectionContext.withSslSession(() -> toUse.engine().getSession()); + } + + this.context = connectionContext; AtomicReference receiveError = new AtomicReference<>(); diff --git a/src/main/java/io/r2dbc/postgresql/client/SSLSessionHandlerAdapter.java b/src/main/java/io/r2dbc/postgresql/client/SSLSessionHandlerAdapter.java index 9247cd62..616f2da2 100644 --- a/src/main/java/io/r2dbc/postgresql/client/SSLSessionHandlerAdapter.java +++ b/src/main/java/io/r2dbc/postgresql/client/SSLSessionHandlerAdapter.java @@ -45,7 +45,7 @@ final class SSLSessionHandlerAdapter extends AbstractPostgresSSLHandlerAdapter { @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { - if (negotiating) { + if (this.negotiating) { Mono.from(SSLRequest.INSTANCE.encode(this.alloc)).subscribe(ctx::writeAndFlush); } super.channelActive(ctx); @@ -53,7 +53,7 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception { @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { - if (negotiating) { + if (this.negotiating) { // If we receive channel inactive before negotiated, then the inbound has closed early. PostgresqlSslException e = new PostgresqlSslException("Connection closed during SSL negotiation"); completeHandshakeExceptionally(e); @@ -63,7 +63,7 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - if (negotiating) { + if (this.negotiating) { ByteBuf buf = (ByteBuf) msg; char response = (char) buf.readByte(); try { @@ -79,7 +79,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception } } finally { buf.release(); - negotiating = false; + this.negotiating = false; } } else { super.channelRead(ctx, msg); diff --git a/src/test/java/io/r2dbc/postgresql/client/ReactorNettyClientIntegrationTests.java b/src/test/java/io/r2dbc/postgresql/client/ReactorNettyClientIntegrationTests.java index ecb94ef0..c03aba09 100644 --- a/src/test/java/io/r2dbc/postgresql/client/ReactorNettyClientIntegrationTests.java +++ b/src/test/java/io/r2dbc/postgresql/client/ReactorNettyClientIntegrationTests.java @@ -451,6 +451,21 @@ void exchangeSslWithClientCertNoCert() { .expectError(R2dbcPermissionDeniedException.class)); } + @Test + void exchangeSslWitScram() { + client( + c -> c + .sslRootCert(SERVER.getServerCrt()) + .username("test-ssl-scram") + .password("test-ssl-scram"), + c -> c.map(client -> client.createStatement("SELECT 10") + .execute() + .flatMap(r -> r.map((row, meta) -> row.get(0, Integer.class))) + .as(StepVerifier::create) + .expectNext(10) + .verifyComplete())); + } + @Test void exchangeSslWithPassword() { client( diff --git a/src/test/resources/pg_hba.conf b/src/test/resources/pg_hba.conf index 6acee8b0..6330483e 100644 --- a/src/test/resources/pg_hba.conf +++ b/src/test/resources/pg_hba.conf @@ -1,5 +1,6 @@ hostnossl all test all md5 hostnossl all test-scram all scram-sha-256 hostssl all test-ssl all password +hostssl all test-ssl-scram all scram-sha-256 hostssl all test-ssl-with-cert all cert local all all md5