Skip to content

Commit

Permalink
Add support for dynamic usernames and passwords.
Browse files Browse the repository at this point in the history
[closes #613]

Signed-off-by: Mark Paluch <mpaluch@paluch.biz>
  • Loading branch information
mp911de committed Dec 7, 2023
1 parent fc546e4 commit 6110fad
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
import io.r2dbc.postgresql.message.backend.NoticeResponse;
import io.r2dbc.postgresql.util.Assert;
import io.r2dbc.postgresql.util.LogLevel;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Mono;
import reactor.netty.resources.LoopResources;
import reactor.util.annotation.Nullable;

Expand Down Expand Up @@ -103,7 +105,7 @@ public final class PostgresqlConnectionConfiguration {

private final Map<String, String> options;

private final CharSequence password;
private final Publisher<CharSequence> password;

private final boolean preferAttachedBuffers;

Expand All @@ -123,18 +125,18 @@ public final class PostgresqlConnectionConfiguration {

private final TimeZone timeZone;

private final String username;
private final Publisher<String> username;

private PostgresqlConnectionConfiguration(String applicationName, boolean autodetectExtensions, @Nullable boolean compatibilityMode, @Nullable Duration connectTimeout, @Nullable String database,
LogLevel errorResponseLogLevel,
List<Extension> extensions, ToIntFunction<String> fetchSize, boolean forceBinary, @Nullable Duration lockWaitTimeout,
@Nullable LoopResources loopResources,
@Nullable MultiHostConfiguration multiHostConfiguration,
LogLevel noticeLogLevel, @Nullable Map<String, String> options, @Nullable CharSequence password, boolean preferAttachedBuffers,
LogLevel noticeLogLevel, @Nullable Map<String, String> options, Publisher<CharSequence> password, boolean preferAttachedBuffers,
int preparedStatementCacheQueries, @Nullable String schema,
@Nullable SingleHostConfiguration singleHostConfiguration, SSLConfig sslConfig, @Nullable Duration statementTimeout,
boolean tcpKeepAlive, boolean tcpNoDelay, TimeZone timeZone,
String username) {
Publisher<String> username) {
this.applicationName = Assert.requireNonNull(applicationName, "applicationName must not be null");
this.autodetectExtensions = autodetectExtensions;
this.compatibilityMode = compatibilityMode;
Expand Down Expand Up @@ -200,7 +202,7 @@ public String toString() {
", multiHostConfiguration='" + this.multiHostConfiguration + '\'' +
", noticeLogLevel='" + this.noticeLogLevel + '\'' +
", options='" + this.options + '\'' +
", password='" + obfuscate(this.password != null ? this.password.length() : 0) + '\'' +
", password='" + obfuscate(this.password != null ? 4 : 0) + '\'' +
", preferAttachedBuffers=" + this.preferAttachedBuffers +
", singleHostConfiguration=" + this.singleHostConfiguration +
", statementTimeout=" + this.statementTimeout +
Expand Down Expand Up @@ -261,8 +263,7 @@ Map<String, String> getOptions() {
return Collections.unmodifiableMap(this.options);
}

@Nullable
CharSequence getPassword() {
Publisher<CharSequence> getPassword() {
return this.password;
}

Expand Down Expand Up @@ -290,7 +291,7 @@ SingleHostConfiguration getRequiredSingleHostConfiguration() {
return config;
}

String getUsername() {
Publisher<String> getUsername() {
return this.username;
}

Expand Down Expand Up @@ -380,7 +381,7 @@ public static final class Builder {
private Map<String, String> options;

@Nullable
private CharSequence password;
private Publisher<CharSequence> password;

private boolean preferAttachedBuffers = false;

Expand Down Expand Up @@ -423,7 +424,7 @@ public static final class Builder {
private LoopResources loopResources = null;

@Nullable
private String username;
private Publisher<String> username;

private Builder() {
}
Expand Down Expand Up @@ -743,7 +744,31 @@ public Builder options(Map<String, String> options) {
* @return this {@link Builder}
*/
public Builder password(@Nullable CharSequence password) {
this.password = password;
this.password = Mono.justOrEmpty(password);
return this;
}

/**
* Configure the password publisher. The publisher is used on each authentication attempt.
*
* @param password the password
* @return this {@link Builder}
* @since 1.0.3
*/
public Builder password(Publisher<CharSequence> password) {
this.password = Mono.from(password);
return this;
}

/**
* Configure the password supplier. The supplier is used on each authentication attempt.
*
* @param password the password
* @return this {@link Builder}
* @since 1.0.3
*/
public Builder password(Supplier<CharSequence> password) {
this.password = Mono.fromSupplier(password);
return this;
}

Expand Down Expand Up @@ -780,7 +805,6 @@ public Builder preferAttachedBuffers(boolean preferAttachedBuffers) {
*
* @param preparedStatementCacheQueries the preparedStatementCacheQueries
* @return this {@link Builder}
* @throws IllegalArgumentException if {@code username} is {@code null}
* @since 0.8.1
*/
public Builder preparedStatementCacheQueries(int preparedStatementCacheQueries) {
Expand Down Expand Up @@ -1023,10 +1047,34 @@ public Builder timeZone(TimeZone timeZone) {
* @throws IllegalArgumentException if {@code username} is {@code null}
*/
public Builder username(String username) {
this.username = Mono.just(Assert.requireNonNull(username, "username must not be null"));
return this;
}

/**
* Configure the username publisher. The publisher is used on each authentication attempt.
*
* @param username the username
* @return this {@link Builder}
* @throws IllegalArgumentException if {@code username} is {@code null}
*/
public Builder username(Publisher<String> username) {
this.username = Assert.requireNonNull(username, "username must not be null");
return this;
}

/**
* Configure the username supplier. The supplier is used on each authentication attempt.
*
* @param username the username
* @return this {@link Builder}
* @throws IllegalArgumentException if {@code username} is {@code null}
*/
public Builder username(Supplier<String> username) {
this.username = Mono.fromSupplier(Assert.requireNonNull(username, "username must not be null"));
return this;
}

@Override
public String toString() {
return "Builder{" +
Expand All @@ -1044,7 +1092,7 @@ public String toString() {
", multiHostConfiguration='" + this.multiHostConfiguration + '\'' +
", noticeLogLevel='" + this.noticeLogLevel + '\'' +
", parameters='" + this.options + '\'' +
", password='" + obfuscate(this.password != null ? this.password.length() : 0) + '\'' +
", password='" + obfuscate(this.password != null ? 4 : 0) + '\'' +
", preparedStatementCacheQueries='" + this.preparedStatementCacheQueries + '\'' +
", schema='" + this.schema + '\'' +
", singleHostConfiguration='" + this.singleHostConfiguration + '\'' +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.r2dbc.spi.ConnectionFactoryOptions;
import io.r2dbc.spi.ConnectionFactoryProvider;
import io.r2dbc.spi.Option;
import org.reactivestreams.Publisher;
import reactor.netty.resources.LoopResources;

import javax.net.ssl.HostnameVerifier;
Expand All @@ -37,6 +38,7 @@
import java.util.Map;
import java.util.TimeZone;
import java.util.function.Function;
import java.util.function.Supplier;

import static io.r2dbc.spi.ConnectionFactoryOptions.CONNECT_TIMEOUT;
import static io.r2dbc.spi.ConnectionFactoryOptions.DATABASE;
Expand Down Expand Up @@ -290,6 +292,7 @@ public boolean supports(ConnectionFactoryOptions connectionFactoryOptions) {
* @return this {@link PostgresqlConnectionConfiguration.Builder}
* @throws IllegalArgumentException if {@code options} is {@code null}
*/
@SuppressWarnings("unchecked")
private static PostgresqlConnectionConfiguration.Builder fromConnectionFactoryOptions(ConnectionFactoryOptions options) {

Assert.requireNonNull(options, "connectionFactoryOptions must not be null");
Expand Down Expand Up @@ -344,7 +347,6 @@ private static PostgresqlConnectionConfiguration.Builder fromConnectionFactoryOp
mapper.fromTyped(LOOP_RESOURCES).to(builder::loopResources);
mapper.from(NOTICE_LOG_LEVEL).map(it -> OptionMapper.toEnum(it, LogLevel.class)).to(builder::noticeLogLevel);
mapper.from(OPTIONS).map(PostgresqlConnectionFactoryProvider::convertToMap).to(builder::options);
mapper.fromTyped(PASSWORD).to(builder::password);
mapper.from(PORT).map(OptionMapper::toInteger).to(builder::port);
mapper.from(PREFER_ATTACHED_BUFFERS).map(OptionMapper::toBoolean).to(builder::preferAttachedBuffers);
mapper.from(PREPARED_STATEMENT_CACHE_QUERIES).map(OptionMapper::toInteger).to(builder::preparedStatementCacheQueries);
Expand All @@ -363,7 +365,26 @@ private static PostgresqlConnectionConfiguration.Builder fromConnectionFactoryOp

return TimeZone.getTimeZone(it.toString());
}).to(builder::timeZone);
builder.username("" + options.getRequiredValue(USER));

Object user = options.getRequiredValue(USER);
Object password = options.getValue(PASSWORD);

if (user instanceof Supplier) {
builder.username((Supplier<String>) user);
} else if (user instanceof Publisher) {
builder.username((Publisher<String>) user);
} else {
builder.username("" + user);
}
if (password != null) {
if (password instanceof Supplier) {
builder.password((Supplier<CharSequence>) password);
} else if (password instanceof Publisher) {
builder.password((Publisher<CharSequence>) password);
} else {
builder.password((CharSequence) password);
}
}

return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.r2dbc.postgresql.message.backend.AuthenticationMessage;
import io.r2dbc.postgresql.util.Assert;
import reactor.core.publisher.Mono;
import reactor.util.annotation.Nullable;

import java.net.SocketAddress;

Expand All @@ -44,26 +45,54 @@ final class SingleHostConnectionFunction implements ConnectionFunction {
public Mono<Client> connect(SocketAddress endpoint, ConnectionSettings settings) {

return this.upstreamFunction.connect(endpoint, settings)
.delayUntil(client -> StartupMessageFlow
.exchange(this::getAuthenticationHandler, client, this.configuration.getDatabase(), this.configuration.getUsername(),
getParameterProvider(this.configuration, settings))
.delayUntil(client -> getCredentials().flatMapMany(credentials -> StartupMessageFlow
.exchange(auth -> getAuthenticationHandler(auth, credentials), client, this.configuration.getDatabase(), credentials.getUsername(),
getParameterProvider(this.configuration, settings)))
.handle(ExceptionFactory.INSTANCE::handleErrorResponse));
}

private static PostgresStartupParameterProvider getParameterProvider(PostgresqlConnectionConfiguration configuration, ConnectionSettings settings) {
return new PostgresStartupParameterProvider(configuration.getApplicationName(), configuration.getTimeZone(), settings);
}

protected AuthenticationHandler getAuthenticationHandler(AuthenticationMessage message) {
protected AuthenticationHandler getAuthenticationHandler(AuthenticationMessage message, UsernameAndPassword usernameAndPassword) {
if (PasswordAuthenticationHandler.supports(message)) {
CharSequence password = Assert.requireNonNull(this.configuration.getPassword(), "Password must not be null");
return new PasswordAuthenticationHandler(password, this.configuration.getUsername());
CharSequence password = Assert.requireNonNull(usernameAndPassword.getPassword(), "Password must not be null");
return new PasswordAuthenticationHandler(password, usernameAndPassword.getUsername());
} else if (SASLAuthenticationHandler.supports(message)) {
CharSequence password = Assert.requireNonNull(this.configuration.getPassword(), "Password must not be null");
return new SASLAuthenticationHandler(password, this.configuration.getUsername());
CharSequence password = Assert.requireNonNull(usernameAndPassword.getPassword(), "Password must not be null");
return new SASLAuthenticationHandler(password, usernameAndPassword.getUsername());
} else {
throw new IllegalStateException(String.format("Unable to provide AuthenticationHandler capable of handling %s", message));
}
}

Mono<UsernameAndPassword> getCredentials() {

return Mono.zip(Mono.from(this.configuration.getUsername()).single(), Mono.from(this.configuration.getPassword()).singleOptional()).map(it -> {
return new UsernameAndPassword(it.getT1(), it.getT2().orElse(null));
});
}

static class UsernameAndPassword {

final String username;

final @Nullable CharSequence password;

public UsernameAndPassword(String username, @Nullable CharSequence password) {
this.username = username;
this.password = password;
}

public String getUsername() {
return this.username;
}

@Nullable
public CharSequence getPassword() {
return this.password;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ void builderHostAndSocket() {

@Test
void builderNoUsername() {
assertThatIllegalArgumentException().isThrownBy(() -> PostgresqlConnectionConfiguration.builder().username(null))
assertThatIllegalArgumentException().isThrownBy(() -> PostgresqlConnectionConfiguration.builder().username((String) null))
.withMessage("username must not be null");
}

Expand Down Expand Up @@ -84,9 +84,7 @@ void configuration() {
.hasFieldOrPropertyWithValue("database", "test-database")
.hasFieldOrPropertyWithValue("singleHostConfiguration.host", "test-host")
.hasFieldOrProperty("options")
.hasFieldOrPropertyWithValue("password", null)
.hasFieldOrPropertyWithValue("singleHostConfiguration.port", 100)
.hasFieldOrPropertyWithValue("username", "test-username")
.hasFieldOrProperty("sslConfig")
.hasFieldOrPropertyWithValue("tcpKeepAlive", true)
.hasFieldOrPropertyWithValue("tcpNoDelay", false)
Expand Down Expand Up @@ -116,9 +114,7 @@ void configureStatementAndLockTimeouts() {
.hasFieldOrPropertyWithValue("database", "test-database")
.hasFieldOrPropertyWithValue("singleHostConfiguration.host", "test-host")
.hasFieldOrProperty("options")
.hasFieldOrPropertyWithValue("password", null)
.hasFieldOrPropertyWithValue("singleHostConfiguration.port", 100)
.hasFieldOrPropertyWithValue("username", "test-username")
.hasFieldOrProperty("sslConfig")
.hasFieldOrPropertyWithValue("tcpKeepAlive", true)
.hasFieldOrPropertyWithValue("tcpNoDelay", false)
Expand Down Expand Up @@ -160,10 +156,8 @@ void configurationDefaults() {
.hasFieldOrPropertyWithValue("applicationName", "r2dbc-postgresql")
.hasFieldOrPropertyWithValue("database", "test-database")
.hasFieldOrPropertyWithValue("singleHostConfiguration.host", "test-host")
.hasFieldOrPropertyWithValue("password", "test-password")
.hasFieldOrPropertyWithValue("singleHostConfiguration.port", 5432)
.hasFieldOrProperty("options")
.hasFieldOrPropertyWithValue("username", "test-username")
.hasFieldOrProperty("sslConfig")
.hasFieldOrPropertyWithValue("tcpKeepAlive", false)
.hasFieldOrPropertyWithValue("tcpNoDelay", true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import io.r2dbc.spi.ConnectionFactoryOptions;
import io.r2dbc.spi.Option;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;

import java.time.Duration;
import java.util.Arrays;
Expand All @@ -33,6 +35,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.TimeZone;
import java.util.function.Supplier;

import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.AUTODETECT_EXTENSIONS;
import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.COMPATIBILITY_MODE;
Expand Down Expand Up @@ -617,6 +620,34 @@ void shouldConfigureExtensions() {
assertThat(factory.getConfiguration().getExtensions()).containsExactly(testExtension1, testExtension2);
}

@Test
void supportsUsernameAndPasswordSupplier() {
PostgresqlConnectionFactory factory = this.provider.create(builder()
.option(DRIVER, LEGACY_POSTGRESQL_DRIVER)
.option(HOST, "test-host")
.option(Option.valueOf("password"), (Supplier<String>) () -> "test-password")
.option(Option.valueOf("user"), (Supplier<String>) () -> "test-user")
.option(USER, "test-user")
.build());

StepVerifier.create(factory.getConfiguration().getPassword()).expectNext("test-password").verifyComplete();
StepVerifier.create(factory.getConfiguration().getUsername()).expectNext("test-user").verifyComplete();
}

@Test
void supportsUsernameAndPasswordPublisher() {
PostgresqlConnectionFactory factory = this.provider.create(builder()
.option(DRIVER, LEGACY_POSTGRESQL_DRIVER)
.option(HOST, "test-host")
.option(Option.valueOf("password"), Mono.just("test-password"))
.option(Option.valueOf("user"), Mono.just("test-user"))
.option(USER, "test-user")
.build());

StepVerifier.create(factory.getConfiguration().getPassword()).expectNext("test-password").verifyComplete();
StepVerifier.create(factory.getConfiguration().getUsername()).expectNext("test-user").verifyComplete();
}

private static class TestExtension implements Extension {

private final String name;
Expand Down

0 comments on commit 6110fad

Please sign in to comment.