Skip to content

Commit

Permalink
[remote/downloader] Migrate Downloader to take Credentials
Browse files Browse the repository at this point in the history
Progress on bazelbuild#15856
  • Loading branch information
Yannic committed Oct 28, 2022
1 parent e5a7389 commit 4c861fc
Show file tree
Hide file tree
Showing 16 changed files with 123 additions and 39 deletions.
Expand Up @@ -6,6 +6,7 @@ filegroup(
name = "srcs",
srcs = glob(["**"]) + [
"//src/main/java/com/google/devtools/build/lib/authandtls/credentialhelper:srcs",
"//src/main/java/com/google/devtools/build/lib/authandtls/staticcredentials:srcs",
],
visibility = ["//src:__subpackages__"],
)
Expand Down
@@ -0,0 +1,18 @@
load("@rules_java//java:defs.bzl", "java_library")

package(default_visibility = ["//src:__subpackages__"])

filegroup(
name = "srcs",
srcs = glob(["**"]),
visibility = ["//src:__subpackages__"],
)

java_library(
name = "staticcredentials",
srcs = glob(["*.java"]),
deps = [
"//third_party:auth",
"//third_party:guava",
],
)
@@ -0,0 +1,51 @@
package com.google.devtools.build.lib.authandtls.staticcredentials;

import com.google.auth.Credentials;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import java.io.IOException;
import java.net.URI;
import java.util.List;
import java.util.Map;

/** Implementation of {@link Credentials} which provides a static set of credentials. */
public final class StaticCredentials extends Credentials {
private final ImmutableMap<URI, Map<String, List<String>>> credentials;

public StaticCredentials(Map<URI, Map<String, List<String>>> credentials) {
Preconditions.checkNotNull(credentials);

this.credentials = ImmutableMap.copyOf(credentials);
}

public Map<URI, Map<String, List<String>>> getMapForMigration() {
return credentials;
}

@Override
public String getAuthenticationType() {
return "static";
}

@Override
public Map<String, List<String>> getRequestMetadata(URI uri) throws IOException {
Preconditions.checkNotNull(uri);

return credentials.getOrDefault(uri, ImmutableMap.of());
}

@Override
public boolean hasRequestMetadata() {
return true;
}

@Override
public boolean hasRequestMetadataOnly() {
return true;
}

@Override
public void refresh() {
// Can't refresh static credentials.
}
}
Expand Up @@ -14,6 +14,7 @@ java_library(
deps = [
"//src/main/java/com/google/devtools/build/lib/analysis:blaze_version_info",
"//src/main/java/com/google/devtools/build/lib/authandtls",
"//src/main/java/com/google/devtools/build/lib/authandtls/staticcredentials",
"//src/main/java/com/google/devtools/build/lib/bazel/repository/cache",
"//src/main/java/com/google/devtools/build/lib/bazel/repository/cache:events",
"//src/main/java/com/google/devtools/build/lib/buildeventstream",
Expand Down
Expand Up @@ -14,6 +14,7 @@

package com.google.devtools.build.lib.bazel.repository.downloader;

import com.google.auth.Credentials;
import com.google.common.base.Optional;
import com.google.devtools.build.lib.events.ExtendedEventHandler;
import com.google.devtools.build.lib.vfs.Path;
Expand Down Expand Up @@ -47,7 +48,7 @@ public void setDelegate(@Nullable Downloader delegate) {
@Override
public void download(
List<URL> urls,
Map<URI, Map<String, List<String>>> authHeaders,
Credentials credentials,
Optional<Checksum> checksum,
String canonicalId,
Path destination,
Expand All @@ -60,6 +61,6 @@ public void download(
downloader = delegate;
}
downloader.download(
urls, authHeaders, checksum, canonicalId, destination, eventHandler, clientEnv, type);
urls, credentials, checksum, canonicalId, destination, eventHandler, clientEnv, type);
}
}
Expand Up @@ -24,6 +24,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.devtools.build.lib.authandtls.staticcredentials.StaticCredentials;
import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCache;
import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCache.KeyType;
import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCacheHitEvent;
Expand Down Expand Up @@ -256,7 +257,7 @@ public Path download(
try {
downloader.download(
rewrittenUrls,
rewrittenAuthHeaders,
new StaticCredentials(rewrittenAuthHeaders),
checksum,
canonicalId,
destination,
Expand Down Expand Up @@ -337,7 +338,7 @@ public byte[] downloadAndReadOneUrl(
for (int attempt = 0; attempt <= retries; ++attempt) {
try {
return httpDownloader.downloadAndReadOneUrl(
rewrittenUrls.get(0), authHeaders, eventHandler, clientEnv);
rewrittenUrls.get(0), new StaticCredentials(authHeaders), eventHandler, clientEnv);
} catch (ContentLengthMismatchException e) {
if (attempt == retries) {
throw e;
Expand Down
Expand Up @@ -14,6 +14,7 @@

package com.google.devtools.build.lib.bazel.repository.downloader;

import com.google.auth.Credentials;
import com.google.common.base.Optional;
import com.google.devtools.build.lib.events.ExtendedEventHandler;
import com.google.devtools.build.lib.vfs.Path;
Expand All @@ -33,7 +34,7 @@ public interface Downloader {
* caller is responsible for cleaning up outputs of failed downloads.
*
* @param urls list of mirror URLs with identical content
* @param authHeaders map of authentication headers per URL
* @param credentials credentials to use when connecting to URLs
* @param checksum valid checksum which is checked, or absent to disable
* @param output path to the destination file to write
* @param type extension, e.g. "tar.gz" to force on downloaded filename, or empty to not do this
Expand All @@ -42,7 +43,7 @@ public interface Downloader {
*/
void download(
List<URL> urls,
Map<URI, Map<String, List<String>>> authHeaders,
Credentials credentials,
Optional<Checksum> checksum,
String canonicalId,
Path output,
Expand Down
Expand Up @@ -14,13 +14,15 @@

package com.google.devtools.build.lib.bazel.repository.downloader;

import com.google.auth.Credentials;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Function;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.devtools.build.lib.analysis.BlazeVersionInfo;
import com.google.devtools.build.lib.authandtls.staticcredentials.StaticCredentials;
import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
import com.google.devtools.build.lib.events.Event;
import com.google.devtools.build.lib.events.EventHandler;
Expand Down Expand Up @@ -74,7 +76,7 @@ final class HttpConnectorMultiplexer {
}

public HttpStream connect(URL url, Optional<Checksum> checksum) throws IOException {
return connect(url, checksum, ImmutableMap.of(), Optional.absent());
return connect(url, checksum, new StaticCredentials(ImmutableMap.of()), Optional.absent());
}

/**
Expand All @@ -87,7 +89,7 @@ public HttpStream connect(URL url, Optional<Checksum> checksum) throws IOExcepti
*
* @param url the URL to conenct to. can be: file, http, or https
* @param checksum checksum lazily checked on entire payload, or empty to disable
* @param authHeaders the authentication headers
* @param credentials the credentials
* @param type extension, e.g. "tar.gz" to force on downloaded filename, or empty to not do this
* @return an {@link InputStream} of response payload
* @throws IOException if all mirrors are down and contains suppressed exception of each attempt
Expand All @@ -97,15 +99,15 @@ public HttpStream connect(URL url, Optional<Checksum> checksum) throws IOExcepti
public HttpStream connect(
URL url,
Optional<Checksum> checksum,
Map<URI, Map<String, List<String>>> authHeaders,
Credentials credentials,
Optional<String> type)
throws IOException {
Preconditions.checkArgument(HttpUtils.isUrlSupportedByDownloader(url));
if (Thread.interrupted()) {
throw new InterruptedIOException();
}
Function<URL, ImmutableMap<String, List<String>>> headerFunction =
getHeaderFunction(REQUEST_HEADERS, authHeaders);
getHeaderFunction(REQUEST_HEADERS, credentials);
URLConnection connection = connector.connect(url, headerFunction);
return httpStreamFactory.create(
connection,
Expand All @@ -128,20 +130,20 @@ public HttpStream connect(
@VisibleForTesting
static Function<URL, ImmutableMap<String, List<String>>> getHeaderFunction(
Map<String, List<String>> baseHeaders,
Map<URI, Map<String, List<String>>> additionalHeaders) {
Credentials credentials) {
Preconditions.checkNotNull(baseHeaders);
Preconditions.checkNotNull(credentials);

return url -> {
ImmutableMap<String, List<String>> headers = ImmutableMap.copyOf(baseHeaders);
Map<String, List<String>> headers = new HashMap<>(baseHeaders);
try {
if (additionalHeaders.containsKey(url.toURI())) {
Map<String, List<String>> newHeaders = new HashMap<>(headers);
newHeaders.putAll(additionalHeaders.get(url.toURI()));
headers = ImmutableMap.copyOf(newHeaders);
}
} catch (URISyntaxException e) {
// If we can't convert the URL to a URI (because it is syntactically malformed), still try
// to do the connection, not adding authentication information as we cannot look it up.
headers.putAll(credentials.getRequestMetadata(url.toURI()));
} catch (URISyntaxException | IOException e) {
// If we can't convert the URL to a URI (because it is syntactically malformed), or fetching
// credentials fails for any other reason, still try to do the connection, not adding
// authentication information as we cannot look it up.
}
return headers;
return ImmutableMap.copyOf(headers);
};
}
}
Expand Up @@ -14,6 +14,7 @@

package com.google.devtools.build.lib.bazel.repository.downloader;

import com.google.auth.Credentials;
import com.google.common.base.Optional;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
Expand Down Expand Up @@ -63,7 +64,7 @@ public void setTimeoutScaling(float timeoutScaling) {
@Override
public void download(
List<URL> urls,
Map<URI, Map<String, List<String>>> authHeaders,
Credentials credentials,
Optional<Checksum> checksum,
String canonicalId,
Path destination,
Expand All @@ -82,8 +83,8 @@ public void download(
for (URL url : urls) {
SEMAPHORE.acquire();

try (HttpStream payload = multiplexer.connect(url, checksum, authHeaders, type);
OutputStream out = destination.getOutputStream()) {
try (HttpStream payload = multiplexer.connect(url, checksum, credentials, type);
OutputStream out = destination.getOutputStream()) {
try {
ByteStreams.copy(payload, out);
} catch (SocketTimeoutException e) {
Expand Down Expand Up @@ -132,7 +133,7 @@ public void download(
/** Downloads the contents of one URL and reads it into a byte array. */
public byte[] downloadAndReadOneUrl(
URL url,
Map<URI, Map<String, List<String>>> authHeaders,
Credentials credentials,
ExtendedEventHandler eventHandler,
Map<String, String> clientEnv)
throws IOException, InterruptedException {
Expand All @@ -141,7 +142,7 @@ public byte[] downloadAndReadOneUrl(
ByteArrayOutputStream out = new ByteArrayOutputStream();
SEMAPHORE.acquire();
try (HttpStream payload =
multiplexer.connect(url, Optional.absent(), authHeaders, Optional.absent())) {
multiplexer.connect(url, Optional.absent(), credentials, Optional.absent())) {
ByteStreams.copy(payload, out);
} catch (SocketTimeoutException e) {
// SocketTimeoutExceptions are InterruptedIOExceptions; however they do not signify
Expand Down
Expand Up @@ -14,6 +14,7 @@ java_library(
name = "downloader",
srcs = glob(["*.java"]),
deps = [
"//src/main/java/com/google/devtools/build/lib/authandtls/staticcredentials",
"//src/main/java/com/google/devtools/build/lib/bazel/repository/downloader",
"//src/main/java/com/google/devtools/build/lib/events",
"//src/main/java/com/google/devtools/build/lib/remote:ReferenceCountedChannel",
Expand All @@ -22,6 +23,7 @@ java_library(
"//src/main/java/com/google/devtools/build/lib/remote/options",
"//src/main/java/com/google/devtools/build/lib/remote/util",
"//src/main/java/com/google/devtools/build/lib/vfs",
"//third_party:auth",
"//third_party:guava",
"//third_party:jsr305",
"//third_party/grpc-java:grpc-jar",
Expand Down
Expand Up @@ -21,6 +21,7 @@
import build.bazel.remote.asset.v1.Qualifier;
import build.bazel.remote.execution.v2.Digest;
import build.bazel.remote.execution.v2.RequestMetadata;
import com.google.auth.Credentials;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Strings;
import com.google.devtools.build.lib.bazel.repository.downloader.Checksum;
Expand All @@ -41,12 +42,10 @@
import io.grpc.StatusRuntimeException;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URI;
import java.net.URL;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TreeMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -112,7 +111,7 @@ public void close() {
@Override
public void download(
List<URL> urls,
Map<URI, Map<String, List<String>>> authHeaders,
Credentials credentials,
com.google.common.base.Optional<Checksum> checksum,
String canonicalId,
Path destination,
Expand Down Expand Up @@ -156,7 +155,7 @@ public void download(
eventHandler.handle(
Event.warn("Remote Cache: " + Utils.grpcAwareErrorMessage(e, verboseFailures)));
fallbackDownloader.download(
urls, authHeaders, checksum, canonicalId, destination, eventHandler, clientEnv, type);
urls, credentials, checksum, canonicalId, destination, eventHandler, clientEnv, type);
}
}

Expand Down
Expand Up @@ -17,6 +17,7 @@ java_library(
srcs = glob(["*.java"]),
deps = [
"//src/main/java/com/google/devtools/build/lib/authandtls",
"//src/main/java/com/google/devtools/build/lib/authandtls/staticcredentials",
"//src/main/java/com/google/devtools/build/lib/bazel/repository/cache",
"//src/main/java/com/google/devtools/build/lib/bazel/repository/downloader",
"//src/main/java/com/google/devtools/build/lib/events",
Expand Down
Expand Up @@ -32,6 +32,7 @@
import com.google.common.base.Optional;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.devtools.build.lib.authandtls.staticcredentials.StaticCredentials;
import com.google.devtools.build.lib.bazel.repository.cache.RepositoryCache.KeyType;
import com.google.devtools.build.lib.bazel.repository.downloader.RetryingInputStream.Reconnector;
import com.google.devtools.build.lib.events.EventHandler;
Expand Down Expand Up @@ -163,7 +164,8 @@ public void testHeaderComputationFunction() throws Exception {
ImmutableMap.of("Authentication", ImmutableList.of("Zm9vOmZvb3NlY3JldA==")));

Function<URL, ImmutableMap<String, List<String>>> headerFunction =
HttpConnectorMultiplexer.getHeaderFunction(baseHeaders, additionalHeaders);
HttpConnectorMultiplexer.getHeaderFunction(
baseHeaders, new StaticCredentials(additionalHeaders));

// Unrelated URL
assertThat(headerFunction.apply(new URL("http://example.org/some/path/file.txt")))
Expand Down Expand Up @@ -215,7 +217,8 @@ public void testHeaderComputationFunction() throws Exception {
ImmutableMap<String, List<String>> annonAuth =
ImmutableMap.of("Authentication", ImmutableList.of("YW5vbnltb3VzOmZvb0BleGFtcGxlLm9yZw=="));
Function<URL, ImmutableMap<String, List<String>>> combinedHeaders =
HttpConnectorMultiplexer.getHeaderFunction(annonAuth, additionalHeaders);
HttpConnectorMultiplexer.getHeaderFunction(
annonAuth, new StaticCredentials(additionalHeaders));
assertThat(combinedHeaders.apply(new URL("http://hosting.example.com/user/foo/file.txt")))
.containsExactly("Authentication", ImmutableList.of("Zm9vOmZvb3NlY3JldA=="));
assertThat(combinedHeaders.apply(new URL("http://unreleated.example.org/user/foo/file.txt")))
Expand Down

0 comments on commit 4c861fc

Please sign in to comment.