-
Notifications
You must be signed in to change notification settings - Fork 176
/
SASLAuthenticationHandler.java
145 lines (117 loc) · 5.95 KB
/
SASLAuthenticationHandler.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
package io.r2dbc.postgresql.authentication;
import com.ongres.scram.client.ScramClient;
import com.ongres.scram.common.StringPreparation;
import com.ongres.scram.common.exception.ScramException;
import com.ongres.scram.common.util.TlsServerEndpoint;
import io.r2dbc.postgresql.client.ConnectionContext;
import io.r2dbc.postgresql.message.backend.AuthenticationMessage;
import io.r2dbc.postgresql.message.backend.AuthenticationSASL;
import io.r2dbc.postgresql.message.backend.AuthenticationSASLContinue;
import io.r2dbc.postgresql.message.backend.AuthenticationSASLFinal;
import io.r2dbc.postgresql.message.frontend.FrontendMessage;
import io.r2dbc.postgresql.message.frontend.SASLInitialResponse;
import io.r2dbc.postgresql.message.frontend.SASLResponse;
import io.r2dbc.postgresql.util.Assert;
import io.r2dbc.postgresql.util.ByteBufferUtils;
import reactor.core.Exceptions;
import reactor.util.Logger;
import reactor.util.Loggers;
import reactor.util.annotation.Nullable;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
/**
 * {@link AuthenticationHandler} that performs PostgreSQL SASL authentication using a SCRAM client.
 * <p>
 * The exchange is stateful and spans three backend messages:
 * <ol>
 * <li>{@link AuthenticationSASL} creates the {@link ScramClient} from the server's advertised mechanisms and
 * answers with a {@link SASLInitialResponse} (client-first message).</li>
 * <li>{@link AuthenticationSASLContinue} feeds the server-first message to the client and answers with a
 * {@link SASLResponse} (client-final message).</li>
 * <li>{@link AuthenticationSASLFinal} feeds the server-final message to the client for verification; no further
 * frontend message is sent.</li>
 * </ol>
 * When the connection has a valid {@link SSLSession}, channel-binding data derived from the peer certificate is
 * passed to the SCRAM client.
 * <p>
 * NOTE(review): the SCRAM client is kept in an unsynchronized mutable field, so an instance presumably serves a
 * single connection handshake at a time — confirm against callers.
 */
public class SASLAuthenticationHandler implements AuthenticationHandler {

    private static final Logger LOG = Loggers.getLogger(SASLAuthenticationHandler.class);

    private final CharSequence password;

    private final String username;

    private final ConnectionContext context;

    /**
     * SCRAM client for the ongoing exchange; {@code null} until an {@link AuthenticationSASL} message was handled.
     */
    @Nullable
    private ScramClient scramClient;

    /**
     * Create a new handler.
     *
     * @param password the password to use for authentication
     * @param username the username to use for authentication
     * @param context  the connection context
     * @throws IllegalArgumentException if {@code password}, {@code username}, or {@code context} is {@code null}
     */
    public SASLAuthenticationHandler(CharSequence password, String username, ConnectionContext context) {
        this.password = Assert.requireNonNull(password, "password must not be null");
        this.username = Assert.requireNonNull(username, "username must not be null");
        this.context = Assert.requireNonNull(context, "context must not be null");
    }

    /**
     * Returns whether this {@link AuthenticationHandler} can support authentication for a given authentication message response.
     *
     * @param message the message to inspect
     * @return whether this {@link AuthenticationHandler} can support authentication for a given authentication message response
     * @throws IllegalArgumentException if {@code message} is {@code null}
     */
    public static boolean supports(AuthenticationMessage message) {
        Assert.requireNonNull(message, "message must not be null");

        return message instanceof AuthenticationSASL || message instanceof AuthenticationSASLContinue || message instanceof AuthenticationSASLFinal;
    }

    /**
     * Advance the SASL exchange with the given backend message.
     *
     * @param message the backend authentication message
     * @return the frontend message to send in reply, or {@code null} after {@link AuthenticationSASLFinal} since the
     * exchange is complete
     * @throws IllegalArgumentException if {@code message} is {@code null} or not a SASL authentication message
     * @throws IllegalStateException    if a continue/final message arrives before {@link AuthenticationSASL}
     */
    @Override
    public FrontendMessage handle(AuthenticationMessage message) {
        Assert.requireNonNull(message, "message must not be null");

        if (message instanceof AuthenticationSASL) {
            return handleAuthenticationSASL((AuthenticationSASL) message);
        }

        if (message instanceof AuthenticationSASLContinue) {
            return handleAuthenticationSASLContinue((AuthenticationSASLContinue) message);
        }

        if (message instanceof AuthenticationSASLFinal) {
            return handleAuthenticationSASLFinal((AuthenticationSASLFinal) message);
        }

        throw new IllegalArgumentException(String.format("Cannot handle %s message", message.getClass().getSimpleName()));
    }

    /**
     * Start the exchange: build the SCRAM client and produce the client-first message.
     */
    private FrontendMessage handleAuthenticationSASL(AuthenticationSASL message) {

        // Copy into a char[] directly instead of via toString(), which would create an extra immutable String copy
        // of the password.
        char[] password = new char[this.password.length()];
        for (int i = 0; i < password.length; i++) {
            password[i] = this.password.charAt(i);
        }

        ScramClient.FinalBuildStage builder = ScramClient.builder()
            .advertisedMechanisms(message.getAuthenticationMechanisms())
            .username(this.username) // ignored by the server, use startup message
            .password(password)
            .stringPreparation(StringPreparation.POSTGRESQL_PREPARATION);

        SSLSession sslSession = this.context.getSslSession();

        // Only offer channel binding when the connection is running over a live TLS session.
        if (sslSession != null && sslSession.isValid()) {
            builder.channelBinding(TlsServerEndpoint.TLS_SERVER_END_POINT, extractSslEndpoint(sslSession));
        }

        this.scramClient = builder.build();

        return new SASLInitialResponse(ByteBufferUtils.encode(this.scramClient.clientFirstMessage().toString()), this.scramClient.getScramMechanism().getName());
    }

    /**
     * Compute channel-binding data from the peer's (first) certificate of the TLS session.
     *
     * @param sslSession the active TLS session
     * @return the channel-binding data, or an empty array if no X.509 peer certificate is available or extraction
     * fails (the failure is logged at debug level)
     */
    private static byte[] extractSslEndpoint(SSLSession sslSession) {
        try {
            Certificate[] certificates = sslSession.getPeerCertificates();
            if (certificates != null && certificates.length > 0) {
                Certificate peerCert = certificates[0]; // First certificate is the peer's certificate
                if (peerCert instanceof X509Certificate) {
                    X509Certificate cert = (X509Certificate) peerCert;
                    return TlsServerEndpoint.getChannelBindingData(cert);
                }
            }
        } catch (CertificateException | SSLException e) {
            LOG.debug("Cannot extract X509Certificate from SSL session", e);
        }
        return new byte[0];
    }

    /**
     * Process the server-first message and produce the client-final message.
     */
    private FrontendMessage handleAuthenticationSASLContinue(AuthenticationSASLContinue message) {
        ScramClient client = requireScramClient(message);
        try {
            client.serverFirstMessage(ByteBufferUtils.decode(message.getData()));

            return new SASLResponse(ByteBufferUtils.encode(client.clientFinalMessage().toString()));
        } catch (ScramException e) {
            throw Exceptions.propagate(e);
        }
    }

    /**
     * Verify the server-final message; the exchange is complete, so no frontend message is returned.
     */
    @Nullable
    private FrontendMessage handleAuthenticationSASLFinal(AuthenticationSASLFinal message) {
        ScramClient client = requireScramClient(message);
        try {
            client.serverFinalMessage(ByteBufferUtils.decode(message.getAdditionalData()));
            return null;
        } catch (ScramException e) {
            throw Exceptions.propagate(e);
        }
    }

    /**
     * Return the SCRAM client of the ongoing exchange, failing with a descriptive error instead of a
     * {@link NullPointerException} if the server skipped the initial {@link AuthenticationSASL} message.
     */
    private ScramClient requireScramClient(AuthenticationMessage message) {
        ScramClient client = this.scramClient;
        if (client == null) {
            throw new IllegalStateException(String.format("Received %s before AuthenticationSASL; SASL exchange has not been started",
                message.getClass().getSimpleName()));
        }
        return client;
    }

}