Skip to content

Commit

Permalink
api: add support for SocketAddress types in ManagedChannelProvider (#…
Browse files Browse the repository at this point in the history
…9076)

* api: add support for SocketAddress types in ManagedChannelProvider
also add support for SocketAddress types in NameResolverProvider
Use scheme in target URI to select a NameRseolverProvider and get
that provider's supported SocketAddress types.
implement selection in ManagedChannelRegistry of appropriate
ManagedChannelProvider based on NameResolver's SocketAddress types
  • Loading branch information
sanjaypujare committed Apr 22, 2022
1 parent 8e65700 commit 538db03
Show file tree
Hide file tree
Showing 11 changed files with 381 additions and 0 deletions.
7 changes: 7 additions & 0 deletions api/src/main/java/io/grpc/ManagedChannelProvider.java
Expand Up @@ -17,6 +17,8 @@
package io.grpc;

import com.google.common.base.Preconditions;
import java.net.SocketAddress;
import java.util.Collection;

/**
* Provider of managed channels for transport agnostic consumption.
Expand Down Expand Up @@ -79,6 +81,11 @@ protected NewChannelBuilderResult newChannelBuilder(String target, ChannelCreden
return NewChannelBuilderResult.error("ChannelCredentials are unsupported");
}

/**
* Returns the {@link SocketAddress} types this ManagedChannelProvider supports.
*/
protected abstract Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes();

public static final class NewChannelBuilderResult {
private final ManagedChannelBuilder<?> channelBuilder;
private final String error;
Expand Down
36 changes: 36 additions & 0 deletions api/src/main/java/io/grpc/ManagedChannelRegistry.java
Expand Up @@ -18,7 +18,12 @@

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.net.SocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashSet;
Expand Down Expand Up @@ -144,6 +149,28 @@ static List<Class<?>> getHardCodedClasses() {
}

ManagedChannelBuilder<?> newChannelBuilder(String target, ChannelCredentials creds) {
return newChannelBuilder(NameResolverRegistry.getDefaultRegistry(), target, creds);
}

@VisibleForTesting
ManagedChannelBuilder<?> newChannelBuilder(NameResolverRegistry nameResolverRegistry,
String target, ChannelCredentials creds) {
NameResolverProvider nameResolverProvider = null;
try {
URI uri = new URI(target);
nameResolverProvider = nameResolverRegistry.providers().get(uri.getScheme());
} catch (URISyntaxException ignore) {
// bad URI found, just ignore and continue
}
if (nameResolverProvider == null) {
nameResolverProvider = nameResolverRegistry.providers().get(
nameResolverRegistry.asFactory().getDefaultScheme());
}
Collection<Class<? extends SocketAddress>> nameResolverSocketAddressTypes
= (nameResolverProvider != null)
? nameResolverProvider.getProducedSocketAddressTypes() :
Collections.emptySet();

List<ManagedChannelProvider> providers = providers();
if (providers.isEmpty()) {
throw new ProviderNotFoundException("No functional channel service provider found. "
Expand All @@ -152,6 +179,15 @@ ManagedChannelBuilder<?> newChannelBuilder(String target, ChannelCredentials cre
}
StringBuilder error = new StringBuilder();
for (ManagedChannelProvider provider : providers()) {
Collection<Class<? extends SocketAddress>> channelProviderSocketAddressTypes
= provider.getSupportedSocketAddressTypes();
if (!channelProviderSocketAddressTypes.containsAll(nameResolverSocketAddressTypes)) {
error.append("; ");
error.append(provider.getClass().getName());
error.append(": does not support 1 or more of ");
error.append(Arrays.toString(nameResolverSocketAddressTypes.toArray()));
continue;
}
ManagedChannelProvider.NewChannelBuilderResult result
= provider.newChannelBuilder(target, creds);
if (result.getChannelBuilder() != null) {
Expand Down
14 changes: 14 additions & 0 deletions api/src/main/java/io/grpc/NameResolverProvider.java
Expand Up @@ -17,6 +17,10 @@
package io.grpc;

import io.grpc.NameResolver.Factory;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Collection;
import java.util.Collections;

/**
* Provider of name resolvers for name agnostic consumption.
Expand Down Expand Up @@ -62,4 +66,14 @@ public abstract class NameResolverProvider extends NameResolver.Factory {
protected String getScheme() {
return getDefaultScheme();
}

/**
* Returns the {@link SocketAddress} types this provider's name-resolver is capable of producing.
* This enables selection of the appropriate {@link ManagedChannelProvider} for a channel.
*
* @return the {@link SocketAddress} types this provider's name-resolver is capable of producing.
*/
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return Collections.singleton(InetSocketAddress.class);
}
}
261 changes: 261 additions & 0 deletions api/src/test/java/io/grpc/ManagedChannelRegistryTest.java
Expand Up @@ -19,6 +19,12 @@
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;

import com.google.common.collect.ImmutableSet;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.URI;
import java.util.Collection;
import java.util.Collections;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
Expand Down Expand Up @@ -156,6 +162,256 @@ public void newChannelBuilder_noProvider() {
}
}

@Test
public void newChannelBuilder_usesScheme() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();
class SocketAddress1 extends SocketAddress {
}

class SocketAddress2 extends SocketAddress {
}

nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") {
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}
});
nameResolverRegistry.register(new BaseNameResolverProvider(true, 6, "sc2") {
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
fail("Should not be called");
throw new AssertionError();
}
});

ManagedChannelRegistry registry = new ManagedChannelRegistry();
registry.register(new BaseProvider(true, 5) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress2.class);
}

@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
fail("Should not be called");
throw new AssertionError();
}
});
class MockChannelBuilder extends ForwardingChannelBuilder<MockChannelBuilder> {
@Override public ManagedChannelBuilder<?> delegate() {
throw new UnsupportedOperationException();
}
}

final ManagedChannelBuilder<?> mcb = new MockChannelBuilder();
registry.register(new BaseProvider(true, 4) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}

@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
return NewChannelBuilderResult.channelBuilder(mcb);
}
});
assertThat(
registry.newChannelBuilder(nameResolverRegistry, "sc1:" + target, creds)).isSameInstanceAs(
mcb);
}

@Test
public void newChannelBuilder_unsupportedSocketAddressTypes() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();
class SocketAddress1 extends SocketAddress {
}

class SocketAddress2 extends SocketAddress {
}

nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") {
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return ImmutableSet.of(SocketAddress1.class, SocketAddress2.class);
}
});

ManagedChannelRegistry registry = new ManagedChannelRegistry();
registry.register(new BaseProvider(true, 5) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress2.class);
}

@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
fail("Should not be called");
throw new AssertionError();
}
});
class MockChannelBuilder extends ForwardingChannelBuilder<MockChannelBuilder> {
@Override public ManagedChannelBuilder<?> delegate() {
throw new UnsupportedOperationException();
}
}

registry.register(new BaseProvider(true, 4) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}

@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
fail("Should not be called");
throw new AssertionError();
}
});
try {
registry.newChannelBuilder(nameResolverRegistry, "sc1:" + target, creds);
fail("expected exception");
} catch (ManagedChannelRegistry.ProviderNotFoundException ex) {
assertThat(ex).hasMessageThat().contains("does not support 1 or more of");
assertThat(ex).hasMessageThat().contains("SocketAddress1");
assertThat(ex).hasMessageThat().contains("SocketAddress2");
}
}

@Test
public void newChannelBuilder_emptySet_asDefault() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();

ManagedChannelRegistry registry = new ManagedChannelRegistry();
class MockChannelBuilder extends ForwardingChannelBuilder<MockChannelBuilder> {
@Override public ManagedChannelBuilder<?> delegate() {
throw new UnsupportedOperationException();
}
}

final ManagedChannelBuilder<?> mcb = new MockChannelBuilder();
registry.register(new BaseProvider(true, 4) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.emptySet();
}

@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
return NewChannelBuilderResult.channelBuilder(mcb);
}
});
assertThat(
registry.newChannelBuilder(nameResolverRegistry, "sc1:" + target, creds)).isSameInstanceAs(
mcb);
}

@Test
public void newChannelBuilder_noSchemeUsesDefaultScheme() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();
class SocketAddress1 extends SocketAddress {
}

nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") {
@Override
protected Collection<Class<? extends SocketAddress>> getProducedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}
});

ManagedChannelRegistry registry = new ManagedChannelRegistry();
class MockChannelBuilder extends ForwardingChannelBuilder<MockChannelBuilder> {
@Override public ManagedChannelBuilder<?> delegate() {
throw new UnsupportedOperationException();
}
}

final ManagedChannelBuilder<?> mcb = new MockChannelBuilder();
registry.register(new BaseProvider(true, 4) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}

@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
return NewChannelBuilderResult.channelBuilder(mcb);
}
});
assertThat(registry.newChannelBuilder(nameResolverRegistry, target, creds)).isSameInstanceAs(
mcb);
}

@Test
public void newChannelBuilder_badUri() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();
class SocketAddress1 extends SocketAddress {
}

ManagedChannelRegistry registry = new ManagedChannelRegistry();

class MockChannelBuilder extends ForwardingChannelBuilder<MockChannelBuilder> {
@Override public ManagedChannelBuilder<?> delegate() {
throw new UnsupportedOperationException();
}
}

final ManagedChannelBuilder<?> mcb = new MockChannelBuilder();
registry.register(new BaseProvider(true, 4) {
@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(SocketAddress1.class);
}

@Override
public NewChannelBuilderResult newChannelBuilder(
String passedTarget, ChannelCredentials passedCreds) {
return NewChannelBuilderResult.channelBuilder(mcb);
}
});
assertThat(
registry.newChannelBuilder(nameResolverRegistry, ":testing123", creds)).isSameInstanceAs(
mcb);
}

private static class BaseNameResolverProvider extends NameResolverProvider {
private final boolean isAvailable;
private final int priority;
private final String defaultScheme;

public BaseNameResolverProvider(boolean isAvailable, int priority, String defaultScheme) {
this.isAvailable = isAvailable;
this.priority = priority;
this.defaultScheme = defaultScheme;
}

@Override
public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) {
return null;
}

@Override
public String getDefaultScheme() {
return defaultScheme;
}

@Override
protected boolean isAvailable() {
return isAvailable;
}

@Override
protected int priority() {
return priority;
}
}

private static class BaseProvider extends ManagedChannelProvider {
private final boolean isAvailable;
private final int priority;
Expand Down Expand Up @@ -184,5 +440,10 @@ protected ManagedChannelBuilder<?> builderForAddress(String name, int port) {
protected ManagedChannelBuilder<?> builderForTarget(String target) {
throw new UnsupportedOperationException();
}

@Override
protected Collection<Class<? extends SocketAddress>> getSupportedSocketAddressTypes() {
return Collections.singleton(InetSocketAddress.class);
}
}
}

0 comments on commit 538db03

Please sign in to comment.