Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

api: add support for SocketAddress types in ManagedChannelProvider #9076

Merged
merged 3 commits into from Apr 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions api/src/main/java/io/grpc/ManagedChannelProvider.java
Expand Up @@ -17,6 +17,8 @@
package io.grpc;

import com.google.common.base.Preconditions;
import java.net.SocketAddress;
import java.util.Collection;

/**
* Provider of managed channels for transport agnostic consumption.
Expand Down Expand Up @@ -79,6 +81,11 @@ protected NewChannelBuilderResult newChannelBuilder(String target, ChannelCreden
return NewChannelBuilderResult.error("ChannelCredentials are unsupported");
}

/**
* Returns the {@link SocketAddress} types this ManagedChannelProvider supports.
*/
protected abstract Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes();

public static final class NewChannelBuilderResult {
private final ManagedChannelBuilder<?> channelBuilder;
private final String error;
Expand Down
36 changes: 36 additions & 0 deletions api/src/main/java/io/grpc/ManagedChannelRegistry.java
Expand Up @@ -18,7 +18,12 @@

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.net.SocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashSet;
Expand Down Expand Up @@ -144,6 +149,28 @@ static List<Class<?>> getHardCodedClasses() {
}

ManagedChannelBuilder<?> newChannelBuilder(String target, ChannelCredentials creds) {
return newChannelBuilder(NameResolverRegistry.getDefaultRegistry(), target, creds);
}

@VisibleForTesting
ManagedChannelBuilder<?> newChannelBuilder(NameResolverRegistry nameResolverRegistry,
String target, ChannelCredentials creds) {
NameResolverProvider nameResolverProvider = null;
try {
URI uri = new URI(target);
nameResolverProvider = nameResolverRegistry.providers().get(uri.getScheme());
} catch (URISyntaxException ignore) {
// bad URI found, just ignore and continue
}
if (nameResolverProvider == null) {
nameResolverProvider = nameResolverRegistry.providers().get(
nameResolverRegistry.asFactory().getDefaultScheme());
}
Collection<Class<? extends SocketAddress>> nameResolverSocketAddressTypes
= (nameResolverProvider != null)
? nameResolverProvider.getProducedSocketAddressTypes() :
Collections.emptySet();

List<ManagedChannelProvider> providers = providers();
if (providers.isEmpty()) {
throw new ProviderNotFoundException("No functional channel service provider found. "
Expand All @@ -152,6 +179,15 @@ ManagedChannelBuilder<?> newChannelBuilder(String target, ChannelCredentials cre
}
StringBuilder error = new StringBuilder();
for (ManagedChannelProvider provider : providers()) {
Collection<Class<? extends SocketAddress>> channelProviderSocketAddressTypes
= provider.getSupportedSocketAddressTypes();
if (!channelProviderSocketAddressTypes.containsAll(nameResolverSocketAddressTypes)) {
error.append("; ");
error.append(provider.getClass().getName());
error.append(": does not support 1 or more of ");
error.append(Arrays.toString(nameResolverSocketAddressTypes.toArray()));
continue;
}
ManagedChannelProvider.NewChannelBuilderResult result
= provider.newChannelBuilder(target, creds);
if (result.getChannelBuilder() != null) {
Expand Down
14 changes: 14 additions & 0 deletions api/src/main/java/io/grpc/NameResolverProvider.java
Expand Up @@ -17,6 +17,10 @@
package io.grpc;

import io.grpc.NameResolver.Factory;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Collection;
import java.util.Collections;

/**
* Provider of name resolvers for name agnostic consumption.
Expand Down Expand Up @@ -62,4 +66,14 @@ public abstract class NameResolverProvider extends NameResolver.Factory {
protected String getScheme() {
return getDefaultScheme();
}

/**
* Returns the {@link SocketAddress} types this provider's name-resolver is capable of producing.
* This enables selection of the appropriate {@link ManagedChannelProvider} for a channel.
*
* @return the {@link SocketAddress} types this provider's name-resolver is capable of producing.
*/
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return Collections.singleton(InetSocketAddress.class);
}
}
261 changes: 261 additions & 0 deletions api/src/test/java/io/grpc/ManagedChannelRegistryTest.java
Expand Up @@ -19,6 +19,12 @@
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;

import com.google.common.collect.ImmutableSet;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.URI;
import java.util.Collection;
import java.util.Collections;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
Expand Down Expand Up @@ -156,6 +162,256 @@ public void newChannelBuilder_noProvider() {
}
}

@Test
public void newChannelBuilder_usesScheme() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();
class SocketAddress1 extends SocketAddress {
}

class SocketAddress2 extends SocketAddress {
}

nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") {
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}
});
nameResolverRegistry.register(new BaseNameResolverProvider(true, 6, "sc2") {
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
fail("Should not be called");
throw new AssertionError();
}
});

ManagedChannelRegistry registry = new ManagedChannelRegistry();
registry.register(new BaseProvider(true, 5) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress2.class);
}

@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
fail("Should not be called");
throw new AssertionError();
}
});
class MockChannelBuilder extends ForwardingChannelBuilder<MockChannelBuilder> {
@Override public ManagedChannelBuilder<?> delegate() {
throw new UnsupportedOperationException();
}
}

final ManagedChannelBuilder<?> mcb = new MockChannelBuilder();
registry.register(new BaseProvider(true, 4) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}

@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
return NewChannelBuilderResult.channelBuilder(mcb);
}
});
assertThat(
registry.newChannelBuilder(nameResolverRegistry, "sc1:" + target, creds)).isSameInstanceAs(
mcb);
}

@Test
public void newChannelBuilder_unsupportedSocketAddressTypes() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();
class SocketAddress1 extends SocketAddress {
}

class SocketAddress2 extends SocketAddress {
}

nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") {
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return ImmutableSet.of(SocketAddress1.class, SocketAddress2.class);
}
});

ManagedChannelRegistry registry = new ManagedChannelRegistry();
registry.register(new BaseProvider(true, 5) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress2.class);
}

@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
fail("Should not be called");
throw new AssertionError();
}
});
class MockChannelBuilder extends ForwardingChannelBuilder<MockChannelBuilder> {
@Override public ManagedChannelBuilder<?> delegate() {
throw new UnsupportedOperationException();
}
}

registry.register(new BaseProvider(true, 4) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}

@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
fail("Should not be called");
throw new AssertionError();
}
});
try {
registry.newChannelBuilder(nameResolverRegistry, "sc1:" + target, creds);
fail("expected exception");
} catch (ManagedChannelRegistry.ProviderNotFoundException ex) {
assertThat(ex).hasMessageThat().contains("does not support 1 or more of");
assertThat(ex).hasMessageThat().contains("SocketAddress1");
assertThat(ex).hasMessageThat().contains("SocketAddress2");
}
}

@Test
public void newChannelBuilder_emptySet_asDefault() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();

ManagedChannelRegistry registry = new ManagedChannelRegistry();
class MockChannelBuilder extends ForwardingChannelBuilder<MockChannelBuilder> {
@Override public ManagedChannelBuilder<?> delegate() {
throw new UnsupportedOperationException();
}
}

final ManagedChannelBuilder<?> mcb = new MockChannelBuilder();
registry.register(new BaseProvider(true, 4) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.emptySet();
}

@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
return NewChannelBuilderResult.channelBuilder(mcb);
}
});
assertThat(
registry.newChannelBuilder(nameResolverRegistry, "sc1:" + target, creds)).isSameInstanceAs(
mcb);
}

@Test
public void newChannelBuilder_noSchemeUsesDefaultScheme() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();
class SocketAddress1 extends SocketAddress {
}

nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") {
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}
});

ManagedChannelRegistry registry = new ManagedChannelRegistry();
class MockChannelBuilder extends ForwardingChannelBuilder<MockChannelBuilder> {
@Override public ManagedChannelBuilder<?> delegate() {
throw new UnsupportedOperationException();
}
}

final ManagedChannelBuilder<?> mcb = new MockChannelBuilder();
registry.register(new BaseProvider(true, 4) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}

@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
return NewChannelBuilderResult.channelBuilder(mcb);
}
});
assertThat(registry.newChannelBuilder(nameResolverRegistry, target, creds)).isSameInstanceAs(
mcb);
}

@Test
public void newChannelBuilder_badUri() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();
class SocketAddress1 extends SocketAddress {
}

ManagedChannelRegistry registry = new ManagedChannelRegistry();

class MockChannelBuilder extends ForwardingChannelBuilder<MockChannelBuilder> {
@Override public ManagedChannelBuilder<?> delegate() {
throw new UnsupportedOperationException();
}
}

final ManagedChannelBuilder<?> mcb = new MockChannelBuilder();
registry.register(new BaseProvider(true, 4) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}

@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
return NewChannelBuilderResult.channelBuilder(mcb);
}
});
assertThat(
registry.newChannelBuilder(nameResolverRegistry, ":testing123", creds)).isSameInstanceAs(
mcb);
}

private static class BaseNameResolverProvider extends NameResolverProvider {
private final boolean isAvailable;
private final int priority;
private final String defaultScheme;

public BaseNameResolverProvider(boolean isAvailable, int priority, String defaultScheme) {
this.isAvailable = isAvailable;
this.priority = priority;
this.defaultScheme = defaultScheme;
}

@Override
public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) {
return null;
}

@Override
public String getDefaultScheme() {
return defaultScheme;
}

@Override
protected boolean isAvailable() {
return isAvailable;
}

@Override
protected int priority() {
return priority;
}
}

private static class BaseProvider extends ManagedChannelProvider {
private final boolean isAvailable;
private final int priority;
Expand Down Expand Up @@ -184,5 +440,10 @@ protected ManagedChannelBuilder<?> builderForAddress(String name, int port) {
protected ManagedChannelBuilder<?> builderForTarget(String target) {
throw new UnsupportedOperationException();
}

@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(InetSocketAddress.class);
}
}
}