Skip to content

Commit

Permalink
Set SNI on SSL connections.
Browse files Browse the repository at this point in the history
[resolves #634]
  • Loading branch information
mp911de committed Feb 15, 2024
1 parent ccbcca2 commit f1757df
Show file tree
Hide file tree
Showing 9 changed files with 209 additions and 14 deletions.
1 change: 1 addition & 0 deletions README.md
Expand Up @@ -103,6 +103,7 @@ Optional)_
| `sslCert` | Path to SSL certificate for TLS authentication in PEM format. Can be also a resource path. _(Optional)_
| `sslPassword` | Key password to decrypt SSL key. _(Optional)_
| `sslHostnameVerifier` | `javax.net.ssl.HostnameVerifier` implementation. _(Optional)_
| `sslSni` | Enable/disable SNI to send the configured `host` name during the SSL handshake. Defaults to `true`. _(Optional)_
| `statementTimeout` | Statement timeout. _(Optional)_
| `targetServerType` | Type of server to use when using multi-host operations. Supported values: `ANY`, `PRIMARY`, `SECONDARY`, `PREFER_SECONDARY`. Defaults to `ANY`. _(Optional)_
| `tcpNoDelay` | Enable/disable TCP NoDelay. Enabled by default. _(Optional)_
Expand Down
Expand Up @@ -42,12 +42,18 @@
import reactor.util.annotation.Nullable;

import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SNIHostName;
import javax.net.ssl.SNIServerName;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLParameters;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.InetSocketAddress;
import java.net.MalformedURLException;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.URL;
import java.time.Duration;
import java.util.ArrayList;
Expand Down Expand Up @@ -411,6 +417,12 @@ public static final class Builder {

private Function<SslContextBuilder, SslContextBuilder> sslContextBuilderCustomizer = Function.identity();

private Function<SSLEngine, SSLEngine> sslEngineCustomizer = Function.identity();

private Function<SocketAddress, SSLParameters> sslParametersFactory = it -> new SSLParameters();

private boolean sslSni = true;

private boolean tcpKeepAlive = false;

private boolean tcpNoDelay = true;
Expand Down Expand Up @@ -476,7 +488,7 @@ public PostgresqlConnectionConfiguration build() {
this.extensions, this.fetchSize, this.forceBinary, this.lockWaitTimeout, this.loopResources, multiHostConfiguration,
this.noticeLogLevel, this.options, this.password, this.preferAttachedBuffers,
this.preparedStatementCacheQueries, this.schema, singleHostConfiguration,
this.createSslConfig(), this.statementTimeout, this.tcpKeepAlive, this.tcpNoDelay, this.timeZone, this.username);
this.createSslConfig(this.sslSni), this.statementTimeout, this.tcpKeepAlive, this.tcpNoDelay, this.timeZone, this.username);
}

/**
Expand Down Expand Up @@ -852,6 +864,47 @@ public Builder sslContextBuilderCustomizer(Function<SslContextBuilder, SslContex
return this;
}

/**
* Configure a {@link SSLEngine} customizer. The customizer gets applied on each SSL connection attempt to allow for just-in-time configuration updates. The {@link Function} gets
* called with a {@link SSLEngine} instance that has all configuration options applied. The customizer may return the same builder or return a new builder instance to be used to
* build the SSL context.
*
* @param sslEngineCustomizer customizer function
* @return this {@link Builder}
* @throws IllegalArgumentException if {@code sslEngineCustomizer} is {@code null}
*/
public Builder sslEngineCustomizer(Function<SSLEngine, SSLEngine> sslEngineCustomizer) {
this.sslEngineCustomizer = Assert.requireNonNull(sslEngineCustomizer, "sslEngineCustomizer must not be null");
return this;
}

/**
* Configure a {@link SSLParameters} provider for a given {@link SocketAddress}. The provider gets applied on each SSL connection attempt to allow for just-in-time configuration updates.
* Typically used to configure SSL protocols
*
* @param sslParametersFactory customizer function
* @return this {@link Builder}
* @throws IllegalArgumentException if {@code sslParametersFactory} is {@code null}
* @since 1.0.4
*/
public Builder sslParameters(Function<SocketAddress, SSLParameters> sslParametersFactory) {
this.sslParametersFactory = Assert.requireNonNull(sslParametersFactory, "sslParametersFactory must not be null");
return this;
}

/**
* Configure whether to indicate the hostname and port via SNI to the server. Enabled by default.
*
* @param sslSni whether to indicate the hostname and port via SNI. Sets {@link SSLParameters#setServerNames(List)} on the {@link SSLParameters} instance provided by
* {@link #sslParameters(Function)}.
* @return this {@link Builder}
* @since 1.0.4
*/
public Builder sslSni(boolean sslSni) {
this.sslSni = sslSni;
return this;
}

/**
* Configure ssl cert for client certificate authentication. Can point to either a resource within the classpath or a file.
*
Expand Down Expand Up @@ -1094,10 +1147,13 @@ public String toString() {
", schema='" + this.schema + '\'' +
", singleHostConfiguration='" + this.singleHostConfiguration + '\'' +
", sslContextBuilderCustomizer='" + this.sslContextBuilderCustomizer + '\'' +
", sslEngineCustomizer='" + this.sslEngineCustomizer + '\'' +
", sslParametersFactory='" + this.sslParametersFactory + '\'' +
", sslMode='" + this.sslMode + '\'' +
", sslRootCert='" + this.sslRootCert + '\'' +
", sslCert='" + this.sslCert + '\'' +
", sslKey='" + this.sslKey + '\'' +
", sslSni='" + this.sslSni + '\'' +
", statementTimeout='" + this.statementTimeout + '\'' +
", sslHostnameVerifier='" + this.sslHostnameVerifier + '\'' +
", tcpKeepAlive='" + this.tcpKeepAlive + '\'' +
Expand All @@ -1107,15 +1163,47 @@ public String toString() {
'}';
}

private SSLConfig createSslConfig() {
private SSLConfig createSslConfig(boolean sslSni) {
if (this.singleHostConfiguration != null && this.singleHostConfiguration.getSocket() != null || this.sslMode == SSLMode.DISABLE) {
return SSLConfig.disabled();
}

HostnameVerifier hostnameVerifier = this.sslHostnameVerifier;
return new SSLConfig(this.sslMode, createSslProvider(), hostnameVerifier);
Function<SocketAddress, SSLParameters> sslParametersFunctionToUse = getSslParametersFactory(sslSni, this.sslParametersFactory);
return new SSLConfig(this.sslMode, createSslProvider(), this.sslEngineCustomizer, sslParametersFunctionToUse, this.sslHostnameVerifier);
}

private static Function<SocketAddress, SSLParameters> getSslParametersFactory(boolean sslSni, Function<SocketAddress, SSLParameters> sslParametersFunction) {
if (!sslSni) {
return sslParametersFunction;
}

return socket -> {

SSLParameters sslParameters = sslParametersFunction.apply(socket);

if (socket instanceof InetSocketAddress) {

InetSocketAddress inetSocketAddress = (InetSocketAddress) socket;
String hostString = inetSocketAddress.getHostString();
if (SSLConfig.isValidSniHostname(hostString)) {
appendSniHost(sslParameters, hostString);
}
}

return sslParameters;
};
}

private static void appendSniHost(SSLParameters sslParameters, String hostString) {

List<SNIServerName> existingServerNames = sslParameters.getServerNames();
List<SNIServerName> serverNames = existingServerNames == null ? new ArrayList<>() : new ArrayList<>(existingServerNames);
serverNames.add(new SNIHostName(hostString));

sslParameters.setServerNames(serverNames);
}


private Supplier<SslContext> createSslProvider() {
SslContextBuilder sslContextBuilder = SslContextBuilder.forClient();
if (this.sslMode.verifyCertificate()) {
Expand Down
Expand Up @@ -221,6 +221,13 @@ public final class PostgresqlConnectionFactoryProvider implements ConnectionFact
*/
public static final Option<String> SSL_ROOT_CERT = Option.valueOf("sslRootCert");

/**
* Configure whether to use SNI on SSL connections. Enabled by default.
*
* @since 1.0.4
*/
public static final Option<Boolean> SSL_SNI = Option.valueOf("sslSni");

/**
* Statement timeout.
*
Expand Down Expand Up @@ -406,6 +413,7 @@ private static void setupSsl(PostgresqlConnectionConfiguration.Builder builder,
mapper.fromTyped(SSL_KEY).to(builder::sslKey);
mapper.fromTyped(SSL_ROOT_CERT).to(builder::sslRootCert);
mapper.fromTyped(SSL_PASSWORD).to(builder::sslPassword);
mapper.fromTyped(SSL_SNI).map(OptionMapper::toBoolean).to(builder::sslSni);

mapper.from(SSL_HOSTNAME_VERIFIER).map(it -> {

Expand Down
Expand Up @@ -28,7 +28,9 @@
import reactor.core.publisher.Mono;

import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLParameters;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.concurrent.CompletableFuture;

abstract class AbstractPostgresSSLHandlerAdapter extends ChannelInboundHandlerAdapter implements GenericFutureListener<Future<Channel>> {
Expand All @@ -41,9 +43,14 @@ abstract class AbstractPostgresSSLHandlerAdapter extends ChannelInboundHandlerAd

private final CompletableFuture<Void> handshakeFuture;

AbstractPostgresSSLHandlerAdapter(ByteBufAllocator alloc, SSLConfig sslConfig) {
AbstractPostgresSSLHandlerAdapter(ByteBufAllocator alloc, SocketAddress socketAddress, SSLConfig sslConfig) {
this.sslConfig = sslConfig;
this.sslEngine = sslConfig.getSslProvider().get().newEngine(alloc);

SSLEngine sslEngine = sslConfig.getSslProvider().get().newEngine(alloc);
SSLParameters sslParameters = sslConfig.getSslParametersFactory().apply(socketAddress);
sslEngine.setSSLParameters(sslParameters);

this.sslEngine = sslConfig.getSslEngineCustomizer().apply(sslEngine);
this.handshakeFuture = new CompletableFuture<>();
this.sslHandler = new SslHandler(this.sslEngine);
this.sslHandler.handshakeFuture().addListener(this);
Expand Down
Expand Up @@ -399,21 +399,21 @@ public static Mono<ReactorNettyClient> connect(SocketAddress socketAddress, Conn
new LoggingHandler(ReactorNettyClient.class, LogLevel.TRACE));
}

registerSslHandler(settings.getSslConfig(), channel);
registerSslHandler(socketAddress, settings.getSslConfig(), channel);
}).connect().flatMap(it ->
getSslHandshake(it.channel()).thenReturn(new ReactorNettyClient(it, settings))
);
}

private static void registerSslHandler(SSLConfig sslConfig, Channel channel) {
private static void registerSslHandler(SocketAddress socketAddress, SSLConfig sslConfig, Channel channel) {
try {
if (sslConfig.getSslMode().startSsl()) {

AbstractPostgresSSLHandlerAdapter sslAdapter;
if (sslConfig.getSslMode() == SSLMode.TUNNEL) {
sslAdapter = new SSLTunnelHandlerAdapter(channel.alloc(), sslConfig);
sslAdapter = new SSLTunnelHandlerAdapter(channel.alloc(), socketAddress, sslConfig);
} else {
sslAdapter = new SSLSessionHandlerAdapter(channel.alloc(), sslConfig);
sslAdapter = new SSLSessionHandlerAdapter(channel.alloc(), socketAddress, sslConfig);
}

channel.pipeline().addFirst(sslAdapter);
Expand Down
63 changes: 63 additions & 0 deletions src/main/java/io/r2dbc/postgresql/client/SSLConfig.java
Expand Up @@ -21,6 +21,10 @@
import reactor.util.annotation.Nullable;

import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLParameters;
import java.net.SocketAddress;
import java.util.function.Function;
import java.util.function.Supplier;

public final class SSLConfig {
Expand All @@ -33,7 +37,16 @@ public final class SSLConfig {
@Nullable
private final Supplier<SslContext> sslProvider;

private final Function<SSLEngine, SSLEngine> sslEngineCustomizer;

private final Function<SocketAddress, SSLParameters> sslParametersFactory;

public SSLConfig(SSLMode sslMode, @Nullable Supplier<SslContext> sslProvider, @Nullable HostnameVerifier hostnameVerifier) {
this(sslMode, sslProvider, Function.identity(), it -> new SSLParameters(), hostnameVerifier);
}

public SSLConfig(SSLMode sslMode, @Nullable Supplier<SslContext> sslProvider, Function<SSLEngine, SSLEngine> sslEngineCustomizer, Function<SocketAddress, SSLParameters> sslParametersFactory,
@Nullable HostnameVerifier hostnameVerifier) {
if (sslMode != SSLMode.DISABLE) {
Assert.requireNonNull(sslProvider, "SslContext provider is required for ssl mode " + sslMode);
}
Expand All @@ -42,6 +55,8 @@ public SSLConfig(SSLMode sslMode, @Nullable Supplier<SslContext> sslProvider, @N
}
this.sslMode = sslMode;
this.sslProvider = sslProvider;
this.sslEngineCustomizer = sslEngineCustomizer;
this.sslParametersFactory = sslParametersFactory;
this.hostnameVerifier = hostnameVerifier;
}

Expand All @@ -64,12 +79,60 @@ public Supplier<SslContext> getSslProvider() {
return this.sslProvider;
}

public Function<SSLEngine, SSLEngine> getSslEngineCustomizer() {
return this.sslEngineCustomizer;
}


public Function<SocketAddress, SSLParameters> getSslParametersFactory() {
return this.sslParametersFactory;
}

public SSLConfig mutateMode(SSLMode newMode) {
return new SSLConfig(
newMode,
this.sslProvider,
this.sslEngineCustomizer,
this.sslParametersFactory,
this.hostnameVerifier
);
}

public static boolean isValidSniHostname(String input) {
for (int i = 0; i < input.length(); i++) {
char c = input.charAt(i);
if (isLabelSeparator(c)) {
continue;
}
if (isNonLDHAsciiCodePoint(c)) {
return false;
}
}
return true;
}

//
// LDH stands for "letter/digit/hyphen", with characters restricted to the
// 26-letter Latin alphabet <A-Z a-z>, the digits <0-9>, and the hyphen
// <->.
// Non LDH refers to characters in the ASCII range, but which are not
// letters, digits or the hypen.
//
// non-LDH = 0..0x2C, 0x2E..0x2F, 0x3A..0x40, 0x5B..0x60, 0x7B..0x7F
//
private static boolean isNonLDHAsciiCodePoint(char ch) {
return (0x0000 <= ch && ch <= 0x002C) ||
(0x002E <= ch && ch <= 0x002F) ||
(0x003A <= ch && ch <= 0x0040) ||
(0x005B <= ch && ch <= 0x0060) ||
(0x007B <= ch && ch <= 0x007F);
}

//
// to check if a character is a label separator, i.e. a dot character.
//
private static boolean isLabelSeparator(char c) {
return (c == '.' || c == '\u3002' || c == '\uFF0E' || c == '\uFF61');
}

}
Expand Up @@ -22,6 +22,8 @@
import io.r2dbc.postgresql.message.frontend.SSLRequest;
import reactor.core.publisher.Mono;

import java.net.SocketAddress;

/**
* SSL handler assuming the endpoint a Postgres endpoint following the {@link SSLRequest} flow.
*
Expand All @@ -35,8 +37,8 @@ final class SSLSessionHandlerAdapter extends AbstractPostgresSSLHandlerAdapter {

private boolean negotiating = true;

SSLSessionHandlerAdapter(ByteBufAllocator alloc, SSLConfig sslConfig) {
super(alloc, sslConfig);
SSLSessionHandlerAdapter(ByteBufAllocator alloc, SocketAddress socketAddress, SSLConfig sslConfig) {
super(alloc, socketAddress, sslConfig);
this.alloc = alloc;
this.sslConfig = sslConfig;
}
Expand Down
Expand Up @@ -19,15 +19,17 @@
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.ChannelHandlerContext;

import java.net.SocketAddress;

/**
* SSL handler assuming the endpoint is a SSL tunnel and not a Postgres endpoint.
*/
final class SSLTunnelHandlerAdapter extends AbstractPostgresSSLHandlerAdapter {

private final SSLConfig sslConfig;

SSLTunnelHandlerAdapter(ByteBufAllocator alloc, SSLConfig sslConfig) {
super(alloc, sslConfig);
SSLTunnelHandlerAdapter(ByteBufAllocator alloc, SocketAddress socketAddress, SSLConfig sslConfig) {
super(alloc, socketAddress, sslConfig);
this.sslConfig = sslConfig;
}

Expand Down

0 comments on commit f1757df

Please sign in to comment.