diff --git a/README.md b/README.md index bc9372d9..33d3b765 100644 --- a/README.md +++ b/README.md @@ -103,6 +103,7 @@ Optional)_ | `sslCert` | Path to SSL certificate for TLS authentication in PEM format. Can be also a resource path. _(Optional)_ | `sslPassword` | Key password to decrypt SSL key. _(Optional)_ | `sslHostnameVerifier` | `javax.net.ssl.HostnameVerifier` implementation. _(Optional)_ +| `sslSni` | Enable/disable SNI to send the configured `host` name during the SSL handshake. Defaults to `true`. _(Optional)_ | `statementTimeout` | Statement timeout. _(Optional)_ | `targetServerType` | Type of server to use when using multi-host operations. Supported values: `ANY`, `PRIMARY`, `SECONDARY`, `PREFER_SECONDARY`. Defaults to `ANY`. _(Optional)_ | `tcpNoDelay` | Enable/disable TCP NoDelay. Enabled by default. _(Optional)_ diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java index 6e5720aa..bda60aab 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java @@ -42,12 +42,18 @@ import reactor.util.annotation.Nullable; import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.SNIHostName; +import javax.net.ssl.SNIServerName; +import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLException; +import javax.net.ssl.SSLParameters; import java.io.File; import java.io.IOException; import java.io.InputStream; +import java.net.InetSocketAddress; import java.net.MalformedURLException; import java.net.Socket; +import java.net.SocketAddress; import java.net.URL; import java.time.Duration; import java.util.ArrayList; @@ -411,6 +417,12 @@ public static final class Builder { private Function sslContextBuilderCustomizer = Function.identity(); + private Function sslEngineCustomizer = Function.identity(); + + private Function sslParametersFactory = it -> new SSLParameters(); + + private boolean sslSni = true; + private boolean tcpKeepAlive = false; private boolean tcpNoDelay = true; @@ -476,7 +488,7 @@ public PostgresqlConnectionConfiguration build() { this.extensions, this.fetchSize, this.forceBinary, this.lockWaitTimeout, this.loopResources, multiHostConfiguration, this.noticeLogLevel, this.options, this.password, this.preferAttachedBuffers, this.preparedStatementCacheQueries, this.schema, singleHostConfiguration, - this.createSslConfig(), this.statementTimeout, this.tcpKeepAlive, this.tcpNoDelay, this.timeZone, this.username); + this.createSslConfig(this.sslSni), this.statementTimeout, this.tcpKeepAlive, this.tcpNoDelay, this.timeZone, this.username); } /** @@ -852,6 +864,47 @@ public Builder sslContextBuilderCustomizer(Function sslEngineCustomizer) { + this.sslEngineCustomizer = Assert.requireNonNull(sslEngineCustomizer, "sslEngineCustomizer must not be null"); + return this; + } + + /** + * Configure a {@link SSLParameters} provider for a given {@link SocketAddress}. The provider gets applied on each SSL connection attempt to allow for just-in-time configuration updates. + * Typically used to configure SSL protocols + * + * @param sslParametersFactory customizer function + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code sslParametersFactory} is {@code null} + * @since 1.0.4 + */ + public Builder sslParameters(Function sslParametersFactory) { + this.sslParametersFactory = Assert.requireNonNull(sslParametersFactory, "sslParametersFactory must not be null"); + return this; + } + + /** + * Configure whether to indicate the hostname and port via SNI to the server. Enabled by default. + * + * @param sslSni whether to indicate the hostname and port via SNI. Sets {@link SSLParameters#setServerNames(List)} on the {@link SSLParameters} instance provided by + * {@link #sslParameters(Function)}. + * @return this {@link Builder} + * @since 1.0.4 + */ + public Builder sslSni(boolean sslSni) { + this.sslSni = sslSni; + return this; + } + /** * Configure ssl cert for client certificate authentication. Can point to either a resource within the classpath or a file. * @@ -1094,10 +1147,13 @@ public String toString() { ", schema='" + this.schema + '\'' + ", singleHostConfiguration='" + this.singleHostConfiguration + '\'' + ", sslContextBuilderCustomizer='" + this.sslContextBuilderCustomizer + '\'' + + ", sslEngineCustomizer='" + this.sslEngineCustomizer + '\'' + + ", sslParametersFactory='" + this.sslParametersFactory + '\'' + ", sslMode='" + this.sslMode + '\'' + ", sslRootCert='" + this.sslRootCert + '\'' + ", sslCert='" + this.sslCert + '\'' + ", sslKey='" + this.sslKey + '\'' + + ", sslSni='" + this.sslSni + '\'' + ", statementTimeout='" + this.statementTimeout + '\'' + ", sslHostnameVerifier='" + this.sslHostnameVerifier + '\'' + ", tcpKeepAlive='" + this.tcpKeepAlive + '\'' + @@ -1107,15 +1163,47 @@ public String toString() { '}'; } - private SSLConfig createSslConfig() { + private SSLConfig createSslConfig(boolean sslSni) { if (this.singleHostConfiguration != null && this.singleHostConfiguration.getSocket() != null || this.sslMode == SSLMode.DISABLE) { return SSLConfig.disabled(); } - HostnameVerifier hostnameVerifier = this.sslHostnameVerifier; - return new SSLConfig(this.sslMode, createSslProvider(), hostnameVerifier); + Function sslParametersFunctionToUse = getSslParametersFactory(sslSni, this.sslParametersFactory); + return new SSLConfig(this.sslMode, createSslProvider(), this.sslEngineCustomizer, sslParametersFunctionToUse, this.sslHostnameVerifier); + } + + private static Function getSslParametersFactory(boolean sslSni, Function sslParametersFunction) { + if (!sslSni) { + return sslParametersFunction; + } + + return socket -> { + + SSLParameters sslParameters = sslParametersFunction.apply(socket); + + if (socket instanceof InetSocketAddress) { + + InetSocketAddress inetSocketAddress = (InetSocketAddress) socket; + String hostString = inetSocketAddress.getHostString(); + if (SSLConfig.isValidSniHostname(hostString)) { + appendSniHost(sslParameters, hostString); + } + } + + return sslParameters; + }; } + private static void appendSniHost(SSLParameters sslParameters, String hostString) { + + List existingServerNames = sslParameters.getServerNames(); + List serverNames = existingServerNames == null ? new ArrayList<>() : new ArrayList<>(existingServerNames); + serverNames.add(new SNIHostName(hostString)); + + sslParameters.setServerNames(serverNames); + } + + private Supplier createSslProvider() { SslContextBuilder sslContextBuilder = SslContextBuilder.forClient(); if (this.sslMode.verifyCertificate()) { diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java index b5180fb4..e2feae0a 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java @@ -221,6 +221,13 @@ public final class PostgresqlConnectionFactoryProvider implements ConnectionFact */ public static final Option SSL_ROOT_CERT = Option.valueOf("sslRootCert"); + /** + * Configure whether to use SNI on SSL connections. Enabled by default. + * + * @since 1.0.4 + */ + public static final Option SSL_SNI = Option.valueOf("sslSni"); + /** * Statement timeout. * @@ -406,6 +413,7 @@ private static void setupSsl(PostgresqlConnectionConfiguration.Builder builder, mapper.fromTyped(SSL_KEY).to(builder::sslKey); mapper.fromTyped(SSL_ROOT_CERT).to(builder::sslRootCert); mapper.fromTyped(SSL_PASSWORD).to(builder::sslPassword); + mapper.fromTyped(SSL_SNI).map(OptionMapper::toBoolean).to(builder::sslSni); mapper.from(SSL_HOSTNAME_VERIFIER).map(it -> { diff --git a/src/main/java/io/r2dbc/postgresql/client/AbstractPostgresSSLHandlerAdapter.java b/src/main/java/io/r2dbc/postgresql/client/AbstractPostgresSSLHandlerAdapter.java index 6edc7341..bc7743a6 100644 --- a/src/main/java/io/r2dbc/postgresql/client/AbstractPostgresSSLHandlerAdapter.java +++ b/src/main/java/io/r2dbc/postgresql/client/AbstractPostgresSSLHandlerAdapter.java @@ -28,7 +28,9 @@ import reactor.core.publisher.Mono; import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLParameters; import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.util.concurrent.CompletableFuture; abstract class AbstractPostgresSSLHandlerAdapter extends ChannelInboundHandlerAdapter implements GenericFutureListener> { @@ -41,9 +43,14 @@ abstract class AbstractPostgresSSLHandlerAdapter extends ChannelInboundHandlerAd private final CompletableFuture handshakeFuture; - AbstractPostgresSSLHandlerAdapter(ByteBufAllocator alloc, SSLConfig sslConfig) { + AbstractPostgresSSLHandlerAdapter(ByteBufAllocator alloc, SocketAddress socketAddress, SSLConfig sslConfig) { this.sslConfig = sslConfig; - this.sslEngine = sslConfig.getSslProvider().get().newEngine(alloc); + + SSLEngine sslEngine = sslConfig.getSslProvider().get().newEngine(alloc); + SSLParameters sslParameters = sslConfig.getSslParametersFactory().apply(socketAddress); + sslEngine.setSSLParameters(sslParameters); + + this.sslEngine = sslConfig.getSslEngineCustomizer().apply(sslEngine); this.handshakeFuture = new CompletableFuture<>(); this.sslHandler = new SslHandler(this.sslEngine); this.sslHandler.handshakeFuture().addListener(this); diff --git a/src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java b/src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java index da55dd3c..1effe894 100644 --- a/src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java +++ b/src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java @@ -399,21 +399,21 @@ public static Mono connect(SocketAddress socketAddress, Conn new LoggingHandler(ReactorNettyClient.class, LogLevel.TRACE)); } - registerSslHandler(settings.getSslConfig(), channel); + registerSslHandler(socketAddress, settings.getSslConfig(), channel); }).connect().flatMap(it -> getSslHandshake(it.channel()).thenReturn(new ReactorNettyClient(it, settings)) ); } - private static void registerSslHandler(SSLConfig sslConfig, Channel channel) { + private static void registerSslHandler(SocketAddress socketAddress, SSLConfig sslConfig, Channel channel) { try { if (sslConfig.getSslMode().startSsl()) { AbstractPostgresSSLHandlerAdapter sslAdapter; if (sslConfig.getSslMode() == SSLMode.TUNNEL) { - sslAdapter = new SSLTunnelHandlerAdapter(channel.alloc(), sslConfig); + sslAdapter = new SSLTunnelHandlerAdapter(channel.alloc(), socketAddress, sslConfig); } else { - sslAdapter = new SSLSessionHandlerAdapter(channel.alloc(), sslConfig); + sslAdapter = new SSLSessionHandlerAdapter(channel.alloc(), socketAddress, sslConfig); } channel.pipeline().addFirst(sslAdapter); diff --git a/src/main/java/io/r2dbc/postgresql/client/SSLConfig.java b/src/main/java/io/r2dbc/postgresql/client/SSLConfig.java index 19203155..d11930d1 100644 --- a/src/main/java/io/r2dbc/postgresql/client/SSLConfig.java +++ b/src/main/java/io/r2dbc/postgresql/client/SSLConfig.java @@ -21,6 +21,10 @@ import reactor.util.annotation.Nullable; import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLParameters; +import java.net.SocketAddress; +import java.util.function.Function; import java.util.function.Supplier; public final class SSLConfig { @@ -33,7 +37,16 @@ public final class SSLConfig { @Nullable private final Supplier sslProvider; + private final Function sslEngineCustomizer; + + private final Function sslParametersFactory; + public SSLConfig(SSLMode sslMode, @Nullable Supplier sslProvider, @Nullable HostnameVerifier hostnameVerifier) { + this(sslMode, sslProvider, Function.identity(), it -> new SSLParameters(), hostnameVerifier); + } + + public SSLConfig(SSLMode sslMode, @Nullable Supplier sslProvider, Function sslEngineCustomizer, Function sslParametersFactory, + @Nullable HostnameVerifier hostnameVerifier) { if (sslMode != SSLMode.DISABLE) { Assert.requireNonNull(sslProvider, "SslContext provider is required for ssl mode " + sslMode); } @@ -42,6 +55,8 @@ public SSLConfig(SSLMode sslMode, @Nullable Supplier sslProvider, @N } this.sslMode = sslMode; this.sslProvider = sslProvider; + this.sslEngineCustomizer = sslEngineCustomizer; + this.sslParametersFactory = sslParametersFactory; this.hostnameVerifier = hostnameVerifier; } @@ -64,12 +79,60 @@ public Supplier getSslProvider() { return this.sslProvider; } + public Function getSslEngineCustomizer() { + return this.sslEngineCustomizer; + } + + + public Function getSslParametersFactory() { + return this.sslParametersFactory; + } + public SSLConfig mutateMode(SSLMode newMode) { return new SSLConfig( newMode, this.sslProvider, + this.sslEngineCustomizer, + this.sslParametersFactory, this.hostnameVerifier ); } + public static boolean isValidSniHostname(String input) { + for (int i = 0; i < input.length(); i++) { + char c = input.charAt(i); + if (isLabelSeparator(c)) { + continue; + } + if (isNonLDHAsciiCodePoint(c)) { + return false; + } + } + return true; + } + + // + // LDH stands for "letter/digit/hyphen", with characters restricted to the + // 26-letter Latin alphabet , the digits <0-9>, and the hyphen + // <->. + // Non LDH refers to characters in the ASCII range, but which are not + // letters, digits or the hypen. + // + // non-LDH = 0..0x2C, 0x2E..0x2F, 0x3A..0x40, 0x5B..0x60, 0x7B..0x7F + // + private static boolean isNonLDHAsciiCodePoint(char ch) { + return (0x0000 <= ch && ch <= 0x002C) || + (0x002E <= ch && ch <= 0x002F) || + (0x003A <= ch && ch <= 0x0040) || + (0x005B <= ch && ch <= 0x0060) || + (0x007B <= ch && ch <= 0x007F); + } + + // + // to check if a character is a label separator, i.e. a dot character. + // + private static boolean isLabelSeparator(char c) { + return (c == '.' || c == '\u3002' || c == '\uFF0E' || c == '\uFF61'); + } + } diff --git a/src/main/java/io/r2dbc/postgresql/client/SSLSessionHandlerAdapter.java b/src/main/java/io/r2dbc/postgresql/client/SSLSessionHandlerAdapter.java index 4e64df8f..9247cd62 100644 --- a/src/main/java/io/r2dbc/postgresql/client/SSLSessionHandlerAdapter.java +++ b/src/main/java/io/r2dbc/postgresql/client/SSLSessionHandlerAdapter.java @@ -22,6 +22,8 @@ import io.r2dbc.postgresql.message.frontend.SSLRequest; import reactor.core.publisher.Mono; +import java.net.SocketAddress; + /** * SSL handler assuming the endpoint a Postgres endpoint following the {@link SSLRequest} flow. * @@ -35,8 +37,8 @@ final class SSLSessionHandlerAdapter extends AbstractPostgresSSLHandlerAdapter { private boolean negotiating = true; - SSLSessionHandlerAdapter(ByteBufAllocator alloc, SSLConfig sslConfig) { - super(alloc, sslConfig); + SSLSessionHandlerAdapter(ByteBufAllocator alloc, SocketAddress socketAddress, SSLConfig sslConfig) { + super(alloc, socketAddress, sslConfig); this.alloc = alloc; this.sslConfig = sslConfig; } diff --git a/src/main/java/io/r2dbc/postgresql/client/SSLTunnelHandlerAdapter.java b/src/main/java/io/r2dbc/postgresql/client/SSLTunnelHandlerAdapter.java index 059d66b3..22accabf 100644 --- a/src/main/java/io/r2dbc/postgresql/client/SSLTunnelHandlerAdapter.java +++ b/src/main/java/io/r2dbc/postgresql/client/SSLTunnelHandlerAdapter.java @@ -19,6 +19,8 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.channel.ChannelHandlerContext; +import java.net.SocketAddress; + /** * SSL handler assuming the endpoint is a SSL tunnel and not a Postgres endpoint. */ @@ -26,8 +28,8 @@ final class SSLTunnelHandlerAdapter extends AbstractPostgresSSLHandlerAdapter { private final SSLConfig sslConfig; - SSLTunnelHandlerAdapter(ByteBufAllocator alloc, SSLConfig sslConfig) { - super(alloc, sslConfig); + SSLTunnelHandlerAdapter(ByteBufAllocator alloc, SocketAddress socketAddress, SSLConfig sslConfig) { + super(alloc, socketAddress, sslConfig); this.sslConfig = sslConfig; } diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderUnitTests.java b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderUnitTests.java index 7a1b36bc..3aabe644 100644 --- a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderUnitTests.java +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderUnitTests.java @@ -28,6 +28,8 @@ import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import javax.net.ssl.SSLParameters; +import java.net.InetSocketAddress; import java.time.Duration; import java.util.Arrays; import java.util.HashMap; @@ -58,6 +60,7 @@ import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.SSL_KEY; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.SSL_MODE; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.SSL_ROOT_CERT; +import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.SSL_SNI; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.STATEMENT_TIMEOUT; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.TARGET_SERVER_TYPE; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.TCP_KEEPALIVE; @@ -446,6 +449,27 @@ void shouldApplySslContextBuilderCustomizer() { .build()); assertThatIllegalStateException().isThrownBy(() -> factory.getConfiguration().getSslConfig().getSslProvider().get()).withMessageContaining("Works!"); + + SSLParameters sslParameters = factory.getConfiguration().getSslConfig().getSslParametersFactory().apply(InetSocketAddress.createUnresolved("myhost", 1)); + + assertThat(sslParameters.getServerNames()).hasSize(1); + } + + @Test + void shouldApplySslSni() { + + PostgresqlConnectionFactory factory = this.provider.create(builder() + .option(DRIVER, POSTGRESQL_DRIVER) + .option(HOST, "test-host") + .option(PASSWORD, "test-password") + .option(USER, "test-user") + .option(SSL_MODE, SSLMode.ALLOW) + .option(SSL_SNI, false) + .build()); + + SSLParameters sslParameters = factory.getConfiguration().getSslConfig().getSslParametersFactory().apply(InetSocketAddress.createUnresolved("myhost", 1)); + + assertThat(sslParameters.getServerNames()).isNull(); } @Test