Skip to content

Commit

Permalink
[repository/downloader] Add support for multiple header values
Browse files Browse the repository at this point in the history
Headers can generally be specified multiple times, not just once.

This refactoring is prerequisite for adding support for getting credentials from the credential helper.

Note that this does not change the Starlark API or the qualifier for the gRPC remote downloader (both of which need to go through the process for incompatible changes - which should hopefully be pretty straight forward to do as this feature doesn't seem to be used that much AFAICT).

Progress on #15856

Closes #16260.

PiperOrigin-RevId: 474770302
Change-Id: I74620e36481414a41108991bc61321d104a6d39a
  • Loading branch information
Yannic authored and Copybara-Service committed Sep 16, 2022
1 parent daabe50 commit a013ad4
Show file tree
Hide file tree
Showing 16 changed files with 171 additions and 118 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public void setDelegate(@Nullable Downloader delegate) {
@Override
public void download(
List<URL> urls,
Map<URI, Map<String, String>> authHeaders,
Map<URI, Map<String, List<String>>> authHeaders,
Optional<Checksum> checksum,
String canonicalId,
Path destination,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ public void setNetrcCreds(Credentials netrcCreds) {
*/
public Path download(
List<URL> originalUrls,
Map<URI, Map<String, String>> authHeaders,
Map<URI, Map<String, List<String>>> authHeaders,
Optional<Checksum> checksum,
String canonicalId,
Optional<String> type,
Expand All @@ -133,7 +133,7 @@ public Path download(
// ctx.download{,_and_extract}, this not the case. Should be refactored to handle all .netrc
// parsing in one place, in Java code (similarly to #downloadAndReadOneUrl).
ImmutableList<URL> rewrittenUrls = ImmutableList.copyOf(originalUrls);
Map<URI, Map<String, String>> rewrittenAuthHeaders = authHeaders;
Map<URI, Map<String, List<String>>> rewrittenAuthHeaders = authHeaders;

if (rewriter != null) {
ImmutableList<UrlRewriter.RewrittenURL> rewrittenUrlMappings = rewriter.amend(originalUrls);
Expand Down Expand Up @@ -303,7 +303,7 @@ public byte[] downloadAndReadOneUrl(
if (Thread.interrupted()) {
throw new InterruptedException();
}
Map<URI, Map<String, String>> authHeaders = ImmutableMap.of();
Map<URI, Map<String, List<String>>> authHeaders = ImmutableMap.of();
ImmutableList<URL> rewrittenUrls = ImmutableList.of(originalUrl);

if (netrcCreds != null) {
Expand All @@ -314,7 +314,7 @@ public byte[] downloadAndReadOneUrl(
authHeaders =
ImmutableMap.of(
originalUrl.toURI(),
ImmutableMap.of(headers.getKey(), headers.getValue().get(0)));
ImmutableMap.of(headers.getKey(), ImmutableList.of(headers.getValue().get(0))));
}
} catch (URISyntaxException e) {
// If the credentials extraction failed, we're letting bazel try without credentials.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public interface Downloader {
*/
void download(
List<URL> urls,
Map<URI, Map<String, String>> authHeaders,
Map<URI, Map<String, List<String>>> authHeaders,
Optional<Checksum> checksum,
String canonicalId,
Path output,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ private int scale(int unscaled) {
return Math.round(unscaled * timeoutScaling);
}

URLConnection connect(URL originalUrl, Function<URL, ImmutableMap<String, String>> requestHeaders)
URLConnection connect(
URL originalUrl, Function<URL, ImmutableMap<String, List<String>>> requestHeaders)
throws IOException {

if (Thread.interrupted()) {
Expand All @@ -116,13 +117,16 @@ URLConnection connect(URL originalUrl, Function<URL, ImmutableMap<String, String
COMPRESSED_EXTENSIONS.contains(HttpUtils.getExtension(url.getPath()))
|| COMPRESSED_EXTENSIONS.contains(HttpUtils.getExtension(originalUrl.getPath()));
connection.setInstanceFollowRedirects(false);
for (Map.Entry<String, String> entry : requestHeaders.apply(url).entrySet()) {
for (Map.Entry<String, List<String>> entry : requestHeaders.apply(url).entrySet()) {
if (isAlreadyCompressed && Ascii.equalsIgnoreCase(entry.getKey(), "Accept-Encoding")) {
// We're not going to ask for compression if we're downloading a file that already
// appears to be compressed.
continue;
}
connection.addRequestProperty(entry.getKey(), entry.getValue());
String key = entry.getKey();
for (String value : entry.getValue()) {
connection.addRequestProperty(key, value);
}
}
if (connection.getRequestProperty("User-Agent") == null) {
connection.setRequestProperty("User-Agent", USER_AGENT_VALUE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@

package com.google.devtools.build.lib.bazel.repository.downloader;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Function;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.devtools.build.lib.analysis.BlazeVersionInfo;
import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
Expand All @@ -30,6 +32,7 @@
import java.net.URL;
import java.net.URLConnection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
Expand All @@ -47,12 +50,12 @@
@ThreadSafe
final class HttpConnectorMultiplexer {

private static final ImmutableMap<String, String> REQUEST_HEADERS =
private static final ImmutableMap<String, List<String>> REQUEST_HEADERS =
ImmutableMap.of(
"Accept-Encoding",
"gzip",
ImmutableList.of("gzip"),
"User-Agent",
"Bazel/" + BlazeVersionInfo.instance().getReleaseName());
ImmutableList.of("Bazel/" + BlazeVersionInfo.instance().getReleaseName()));

private final EventHandler eventHandler;
private final HttpConnector connector;
Expand Down Expand Up @@ -84,50 +87,53 @@ public HttpStream connect(URL url, Optional<Checksum> checksum) throws IOExcepti
*
* @param url the URL to conenct to. can be: file, http, or https
* @param checksum checksum lazily checked on entire payload, or empty to disable
* @return an {@link InputStream} of response payload
* @param authHeaders the authentication headers
* @param type extension, e.g. "tar.gz" to force on downloaded filename, or empty to not do this
* @return an {@link InputStream} of response payload
* @throws IOException if all mirrors are down and contains suppressed exception of each attempt
* @throws InterruptedIOException if current thread is being cast into oblivion
* @throws IllegalArgumentException if {@code urls} is empty or has an unsupported protocol
*/
public HttpStream connect(
URL url,
Optional<Checksum> checksum,
Map<URI, Map<String, String>> authHeaders,
Map<URI, Map<String, List<String>>> authHeaders,
Optional<String> type)
throws IOException {
Preconditions.checkArgument(HttpUtils.isUrlSupportedByDownloader(url));
if (Thread.interrupted()) {
throw new InterruptedIOException();
}
Function<URL, ImmutableMap<String, String>> headerFunction =
Function<URL, ImmutableMap<String, List<String>>> headerFunction =
getHeaderFunction(REQUEST_HEADERS, authHeaders);
URLConnection connection = connector.connect(url, headerFunction);
return httpStreamFactory.create(
connection,
url,
checksum,
(Throwable cause, ImmutableMap<String, String> extraHeaders) -> {
(Throwable cause, ImmutableMap<String, List<String>> extraHeaders) -> {
eventHandler.handle(
Event.progress(String.format("Lost connection for %s due to %s", url, cause)));
return connector.connect(
connection.getURL(),
newUrl ->
new ImmutableMap.Builder<String, String>()
new ImmutableMap.Builder<String, List<String>>()
.putAll(headerFunction.apply(newUrl))
.putAll(extraHeaders)
.buildOrThrow());
},
type);
}

public static Function<URL, ImmutableMap<String, String>> getHeaderFunction(
Map<String, String> baseHeaders, Map<URI, Map<String, String>> additionalHeaders) {
@VisibleForTesting
static Function<URL, ImmutableMap<String, List<String>>> getHeaderFunction(
Map<String, List<String>> baseHeaders,
Map<URI, Map<String, List<String>>> additionalHeaders) {
return url -> {
ImmutableMap<String, String> headers = ImmutableMap.copyOf(baseHeaders);
ImmutableMap<String, List<String>> headers = ImmutableMap.copyOf(baseHeaders);
try {
if (additionalHeaders.containsKey(url.toURI())) {
Map<String, String> newHeaders = new HashMap<>(headers);
Map<String, List<String>> newHeaders = new HashMap<>(headers);
newHeaders.putAll(additionalHeaders.get(url.toURI()));
headers = ImmutableMap.copyOf(newHeaders);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public void setTimeoutScaling(float timeoutScaling) {
@Override
public void download(
List<URL> urls,
Map<URI, Map<String, String>> authHeaders,
Map<URI, Map<String, List<String>>> authHeaders,
Optional<Checksum> checksum,
String canonicalId,
Path destination,
Expand Down Expand Up @@ -132,7 +132,7 @@ public void download(
/** Downloads the contents of one URL and reads it into a byte array. */
public byte[] downloadAndReadOneUrl(
URL url,
Map<URI, Map<String, String>> authHeaders,
Map<URI, Map<String, List<String>>> authHeaders,
ExtendedEventHandler eventHandler,
Map<String, String> clientEnv)
throws IOException, InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
package com.google.devtools.build.lib.bazel.repository.downloader;

import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadCompatible;
import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.net.SocketTimeoutException;
import java.net.URLConnection;
import java.util.List;
import java.util.Vector;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
Expand All @@ -38,11 +40,9 @@ class RetryingInputStream extends InputStream {

/** Lambda for establishing a connection. */
interface Reconnector {

/** Establishes a connection with the same parameters as what was passed to us initially. */
URLConnection connect(
Throwable cause, ImmutableMap<String, String> extraHeaders)
throws IOException;
URLConnection connect(Throwable cause, ImmutableMap<String, List<String>> extraHeaders)
throws IOException;
}

private volatile InputStream delegate;
Expand Down Expand Up @@ -117,11 +117,12 @@ private void reconnectWhereWeLeftOff(IOException cause) throws IOException {
URLConnection connection;
long amountRead = toto.get();
if (amountRead == 0) {
connection = reconnector.connect(cause, ImmutableMap.<String, String>of());
connection = reconnector.connect(cause, ImmutableMap.of());
} else {
connection =
reconnector.connect(
cause, ImmutableMap.of("Range", String.format("bytes=%d-", amountRead)));
cause,
ImmutableMap.of("Range", ImmutableList.of(String.format("bytes=%d-", amountRead))));
if (!Strings.nullToEmpty(connection.getHeaderField("Content-Range"))
.startsWith(String.format("bytes %d-", amountRead))) {
throw new IOException(String.format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,11 @@ public ImmutableList<RewrittenURL> amend(List<URL> urls) {
* @param authHeaders A map of the URLs and their corresponding auth tokens.
* @return A map of the updated authentication headers.
*/
public Map<URI, Map<String, String>> updateAuthHeaders(
List<RewrittenURL> urls, Map<URI, Map<String, String>> authHeaders, Credentials netrcCreds) {
Map<URI, Map<String, String>> updatedAuthHeaders = new HashMap<>(authHeaders);
public Map<URI, Map<String, List<String>>> updateAuthHeaders(
List<RewrittenURL> urls,
Map<URI, Map<String, List<String>>> authHeaders,
Credentials netrcCreds) {
Map<URI, Map<String, List<String>>> updatedAuthHeaders = new HashMap<>(authHeaders);

for (RewrittenURL url : urls) {
// if URL was not re-written by UrlRewriter in first place, we should not attach auth headers
Expand All @@ -142,7 +144,8 @@ public Map<URI, Map<String, String>> updateAuthHeaders(
try {
String token =
"Basic " + Base64.getEncoder().encodeToString(userInfo.getBytes(ISO_8859_1));
updatedAuthHeaders.put(url.url().toURI(), ImmutableMap.of("Authorization", token));
updatedAuthHeaders.put(
url.url().toURI(), ImmutableMap.of("Authorization", ImmutableList.of(token)));
} catch (URISyntaxException e) {
// If the credentials extraction failed, we're letting bazel try without credentials.
}
Expand All @@ -159,7 +162,8 @@ public Map<URI, Map<String, String>> updateAuthHeaders(
if (firstAuthHeader.getValue() != null && !firstAuthHeader.getValue().isEmpty()) {
updatedAuthHeaders.put(
url.url().toURI(),
ImmutableMap.of(firstAuthHeader.getKey(), firstAuthHeader.getValue().get(0)));
ImmutableMap.of(
firstAuthHeader.getKey(), ImmutableList.of(firstAuthHeader.getValue().get(0))));
}
} catch (URISyntaxException | IOException e) {
// If the credentials extraction failed, we're letting bazel try without credentials.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@ protected void checkInOutputDirectory(String operation, StarlarkPath path)
* authentication, adding those headers is enough; for other forms of authentication other
* measures might be necessary.
*/
private static ImmutableMap<URI, Map<String, String>> getAuthHeaders(Map<String, Dict<?, ?>> auth)
throws RepositoryFunctionException, EvalException {
ImmutableMap.Builder<URI, Map<String, String>> headers = new ImmutableMap.Builder<>();
private static ImmutableMap<URI, Map<String, List<String>>> getAuthHeaders(
Map<String, Dict<?, ?>> auth) throws RepositoryFunctionException, EvalException {
ImmutableMap.Builder<URI, Map<String, List<String>>> headers = new ImmutableMap.Builder<>();
for (Map.Entry<String, Dict<?, ?>> entry : auth.entrySet()) {
try {
URL url = new URL(entry.getKey());
Expand All @@ -174,7 +174,9 @@ private static ImmutableMap<URI, Map<String, String>> getAuthHeaders(Map<String,
url.toURI(),
ImmutableMap.of(
"Authorization",
"Basic " + Base64.getEncoder().encodeToString(credentials.getBytes(UTF_8))));
ImmutableList.of(
"Basic "
+ Base64.getEncoder().encodeToString(credentials.getBytes(UTF_8)))));
} else if ("pattern".equals(authMap.get("type"))) {
if (!authMap.containsKey("pattern")) {
throw Starlark.errorf(
Expand All @@ -201,7 +203,7 @@ private static ImmutableMap<URI, Map<String, String>> getAuthHeaders(Map<String,
result = result.replaceAll(demarcatedComponent, (String) authMap.get(component));
}

headers.put(url.toURI(), ImmutableMap.of("Authorization", result));
headers.put(url.toURI(), ImmutableMap.of("Authorization", ImmutableList.of(result)));
}
}
} catch (MalformedURLException e) {
Expand Down Expand Up @@ -445,7 +447,7 @@ public StructImpl download(
String integrity,
StarlarkThread thread)
throws RepositoryFunctionException, EvalException, InterruptedException {
ImmutableMap<URI, Map<String, String>> authHeaders =
ImmutableMap<URI, Map<String, List<String>>> authHeaders =
getAuthHeaders(getAuthContents(authUnchecked, "auth"));

ImmutableList<URL> urls =
Expand Down Expand Up @@ -634,7 +636,7 @@ public StructImpl downloadAndExtract(
Dict<?, ?> renameFiles, // <String, String> expected
StarlarkThread thread)
throws RepositoryFunctionException, InterruptedException, EvalException {
ImmutableMap<URI, Map<String, String>> authHeaders =
ImmutableMap<URI, Map<String, List<String>>> authHeaders =
getAuthHeaders(getAuthContents(auth, "auth"));

ImmutableList<URL> urls =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import build.bazel.remote.execution.v2.RequestMetadata;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Strings;
import com.google.common.collect.Iterables;
import com.google.devtools.build.lib.bazel.repository.downloader.Checksum;
import com.google.devtools.build.lib.bazel.repository.downloader.Downloader;
import com.google.devtools.build.lib.bazel.repository.downloader.HashOutputStream;
Expand Down Expand Up @@ -106,7 +107,7 @@ public void close() {
@Override
public void download(
List<URL> urls,
Map<URI, Map<String, String>> authHeaders,
Map<URI, Map<String, List<String>>> authHeaders,
com.google.common.base.Optional<Checksum> checksum,
String canonicalId,
Path destination,
Expand Down Expand Up @@ -148,7 +149,7 @@ public void download(
static FetchBlobRequest newFetchBlobRequest(
String instanceName,
List<URL> urls,
Map<URI, Map<String, String>> authHeaders,
Map<URI, Map<String, List<String>>> authHeaders,
com.google.common.base.Optional<Checksum> checksum,
String canonicalId) {
FetchBlobRequest.Builder requestBuilder =
Expand Down Expand Up @@ -197,13 +198,17 @@ private OutputStream newOutputStream(
return out;
}

private static String authHeadersJson(Map<URI, Map<String, String>> authHeaders) {
private static String authHeadersJson(Map<URI, Map<String, List<String>>> authHeaders) {
Map<String, JsonObject> subObjects = new TreeMap<>();
for (Map.Entry<URI, Map<String, String>> entry : authHeaders.entrySet()) {
for (Map.Entry<URI, Map<String, List<String>>> entry : authHeaders.entrySet()) {
JsonObject subObject = new JsonObject();
Map<String, String> orderedHeaders = new TreeMap<>(entry.getValue());
for (Map.Entry<String, String> subEntry : orderedHeaders.entrySet()) {
subObject.addProperty(subEntry.getKey(), subEntry.getValue());
Map<String, List<String>> orderedHeaders = new TreeMap<>(entry.getValue());
for (Map.Entry<String, List<String>> subEntry : orderedHeaders.entrySet()) {
// TODO(yannic): Introduce incompatible flag for including all headers, not just the first.
String value = Iterables.getFirst(subEntry.getValue(), null);
if (value != null) {
subObject.addProperty(subEntry.getKey(), value);
}
}
subObjects.put(entry.getKey().toString(), subObject);
}
Expand Down

0 comments on commit a013ad4

Please sign in to comment.