Skip to content

Commit

Permalink
xds: Plumb locality in xds_cluster_impl and weighted_target
Browse files Browse the repository at this point in the history
As part of gRFC A78:

> To support the locality label in the WRR metrics, we will extend the
> `weighted_target` LB policy (see A28) to define a resolver attribute
> that indicates the name of its child. This attribute will be passed
> down to each of its children with the appropriate value, so that any
> LB policy that sits underneath the `weighted_target` policy will be
> able to use it.

xds_cluster_impl is involved because it uses the child names in the
AddressFilter, which must match the names used by weighted_target.
Instead of using Locality.toString() in multiple policies and assuming
the policies agree, we now have xds_cluster_impl decide the locality's
name and pass it down explicitly. This allows us to change the name
format to match gRFC A78:

> If locality information is available, the value of this label will be
> of the form `{region="${REGION}", zone="${ZONE}",
> sub_zone="${SUB_ZONE}"}`, where `${REGION}`, `${ZONE}`, and
> `${SUB_ZONE}` are replaced with the actual values. If no locality
> information is available, the label will be set to the empty string.
  • Loading branch information
ejona86 committed May 3, 2024
1 parent 13a9290 commit 077dcbf
Show file tree
Hide file tree
Showing 9 changed files with 115 additions and 48 deletions.
15 changes: 13 additions & 2 deletions xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java
Expand Up @@ -76,6 +76,8 @@ final class ClusterImplLoadBalancer extends LoadBalancer {

private static final Attributes.Key<ClusterLocalityStats> ATTR_CLUSTER_LOCALITY_STATS =
Attributes.Key.create("io.grpc.xds.ClusterImplLoadBalancer.clusterLocalityStats");
private static final Attributes.Key<String> ATTR_CLUSTER_LOCALITY_NAME =
Attributes.Key.create("io.grpc.xds.ClusterImplLoadBalancer.clusterLocalityName");

private final XdsLogger logger;
private final Helper helper;
Expand Down Expand Up @@ -209,20 +211,25 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) {
List<EquivalentAddressGroup> addresses = withAdditionalAttributes(args.getAddresses());
Locality locality = args.getAddresses().get(0).getAttributes().get(
InternalXdsAttributes.ATTR_LOCALITY); // all addresses should be in the same locality
String localityName = args.getAddresses().get(0).getAttributes().get(
InternalXdsAttributes.ATTR_LOCALITY_NAME);
// Endpoint addresses resolved by ClusterResolverLoadBalancer should always contain
// attributes with its locality, including endpoints in LOGICAL_DNS clusters.
// In case of not (which really shouldn't), loads are aggregated under an empty locality.
if (locality == null) {
locality = Locality.create("", "", "");
localityName = "";
}
final ClusterLocalityStats localityStats =
(lrsServerInfo == null)
? null
: xdsClient.addClusterLocalityStats(lrsServerInfo, cluster,
edsServiceName, locality);

Attributes attrs = args.getAttributes().toBuilder().set(
ATTR_CLUSTER_LOCALITY_STATS, localityStats).build();
Attributes attrs = args.getAttributes().toBuilder()
.set(ATTR_CLUSTER_LOCALITY_STATS, localityStats)
.set(ATTR_CLUSTER_LOCALITY_NAME, localityName)
.build();
args = args.toBuilder().setAddresses(addresses).setAttributes(attrs).build();
final Subchannel subchannel = delegate().createSubchannel(args);

Expand Down Expand Up @@ -344,6 +351,10 @@ public PickResult pickSubchannel(PickSubchannelArgs args) {
final ClusterLocalityStats stats =
result.getSubchannel().getAttributes().get(ATTR_CLUSTER_LOCALITY_STATS);
if (stats != null) {
String localityName =
result.getSubchannel().getAttributes().get(ATTR_CLUSTER_LOCALITY_NAME);
args.getPickDetailsConsumer().addOptionalLabel("grpc.lb.locality", localityName);

ClientStreamTracer.Factory tracerFactory = new CountingStreamTracerFactory(
stats, inFlights, result.getStreamTracerFactory());
ClientStreamTracer.Factory orcaTracerFactory = OrcaPerRequestUtil.getInstance()
Expand Down
20 changes: 13 additions & 7 deletions xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java
Expand Up @@ -412,17 +412,18 @@ public void run() {
if (endpoint.loadBalancingWeight() != 0) {
weight *= endpoint.loadBalancingWeight();
}
String localityName = localityName(locality);
Attributes attr =
endpoint.eag().getAttributes().toBuilder()
.set(InternalXdsAttributes.ATTR_LOCALITY, locality)
.set(InternalXdsAttributes.ATTR_LOCALITY_NAME, localityName)
.set(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT,
localityLbInfo.localityWeight())
.set(InternalXdsAttributes.ATTR_SERVER_WEIGHT, weight)
.build();
EquivalentAddressGroup eag = new EquivalentAddressGroup(
endpoint.eag().getAddresses(), attr);
eag = AddressFilter.setPathFilter(
eag, Arrays.asList(priorityName, localityName(locality)));
eag = AddressFilter.setPathFilter(eag, Arrays.asList(priorityName, localityName));
addresses.add(eag);
}
}
Expand Down Expand Up @@ -612,11 +613,13 @@ public void run() {
for (EquivalentAddressGroup eag : resolutionResult.getAddresses()) {
// No weight attribute is attached, all endpoint-level LB policy should be able
// to handle such it.
Attributes attr = eag.getAttributes().toBuilder().set(
InternalXdsAttributes.ATTR_LOCALITY, LOGICAL_DNS_CLUSTER_LOCALITY).build();
String localityName = localityName(LOGICAL_DNS_CLUSTER_LOCALITY);
Attributes attr = eag.getAttributes().toBuilder()
.set(InternalXdsAttributes.ATTR_LOCALITY, LOGICAL_DNS_CLUSTER_LOCALITY)
.set(InternalXdsAttributes.ATTR_LOCALITY_NAME, localityName)
.build();
eag = new EquivalentAddressGroup(eag.getAddresses(), attr);
eag = AddressFilter.setPathFilter(
eag, Arrays.asList(priorityName, LOGICAL_DNS_CLUSTER_LOCALITY.toString()));
eag = AddressFilter.setPathFilter(eag, Arrays.asList(priorityName, localityName));
addresses.add(eag);
}
PriorityChildConfig priorityChildConfig = generateDnsBasedPriorityChildConfig(
Expand Down Expand Up @@ -844,6 +847,9 @@ private static String priorityName(String cluster, int priority) {
* across all localities in all clusters.
*/
private static String localityName(Locality locality) {
return locality.toString();
return "{region=\"" + locality.region()
+ "\", zone=\"" + locality.zone()
+ "\", sub_zone=\"" + locality.subZone()
+ "\"}";
}
}
7 changes: 7 additions & 0 deletions xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java
Expand Up @@ -77,6 +77,13 @@ public final class InternalXdsAttributes {
static final Attributes.Key<Locality> ATTR_LOCALITY =
Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.locality");

/**
* The name of the locality that this EquivalentAddressGroup is in.
*/
@EquivalentAddressGroup.Attr
static final Attributes.Key<String> ATTR_LOCALITY_NAME =
Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.localityName");

/**
* Endpoint weight for load balancing purposes.
*/
Expand Down
6 changes: 6 additions & 0 deletions xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java
Expand Up @@ -23,6 +23,7 @@
import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;

import com.google.common.collect.ImmutableMap;
import io.grpc.Attributes;
import io.grpc.ConnectivityState;
import io.grpc.InternalLogId;
import io.grpc.LoadBalancer;
Expand All @@ -42,6 +43,8 @@

/** Load balancer for weighted_target policy. */
final class WeightedTargetLoadBalancer extends LoadBalancer {
public static final Attributes.Key<String> CHILD_NAME =
Attributes.Key.create("io.grpc.xds.WeightedTargetLoadBalancer.CHILD_NAME");

private final XdsLogger logger;
private final Map<String, GracefulSwitchLoadBalancer> childBalancers = new HashMap<>();
Expand Down Expand Up @@ -95,6 +98,9 @@ public Status acceptResolvedAddressesInternal(ResolvedAddresses resolvedAddresse
resolvedAddresses.toBuilder()
.setAddresses(AddressFilter.filter(resolvedAddresses.getAddresses(), targetName))
.setLoadBalancingPolicyConfig(targets.get(targetName).policySelection.getConfig())
.setAttributes(resolvedAddresses.getAttributes().toBuilder()
.set(CHILD_NAME, targetName)
.build())
.build());
}

Expand Down
9 changes: 4 additions & 5 deletions xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java
Expand Up @@ -31,7 +31,6 @@
import io.grpc.util.GracefulSwitchLoadBalancer;
import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedPolicySelection;
import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedTargetConfig;
import io.grpc.xds.client.Locality;
import io.grpc.xds.client.XdsLogger;
import io.grpc.xds.client.XdsLogger.XdsLogLevel;
import java.util.HashMap;
Expand Down Expand Up @@ -73,10 +72,10 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
= (WrrLocalityConfig) resolvedAddresses.getLoadBalancingPolicyConfig();

// A map of locality weights is built up from the locality weight attributes in each address.
Map<Locality, Integer> localityWeights = new HashMap<>();
Map<String, Integer> localityWeights = new HashMap<>();
for (EquivalentAddressGroup eag : resolvedAddresses.getAddresses()) {
Attributes eagAttrs = eag.getAttributes();
Locality locality = eagAttrs.get(InternalXdsAttributes.ATTR_LOCALITY);
String locality = eagAttrs.get(InternalXdsAttributes.ATTR_LOCALITY_NAME);
Integer localityWeight = eagAttrs.get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT);

if (locality == null) {
Expand Down Expand Up @@ -106,8 +105,8 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
// Weighted target LB expects a WeightedPolicySelection for each locality as it will create a
// child LB for each.
Map<String, WeightedPolicySelection> weightedPolicySelections = new HashMap<>();
for (Locality locality : localityWeights.keySet()) {
weightedPolicySelections.put(locality.toString(),
for (String locality : localityWeights.keySet()) {
weightedPolicySelections.put(locality,
new WeightedPolicySelection(localityWeights.get(locality),
wrrLocalityConfig.childPolicy));
}
Expand Down
62 changes: 50 additions & 12 deletions xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java
Expand Up @@ -19,19 +19,22 @@
import static com.google.common.truth.Truth.assertThat;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.github.xds.data.orca.v3.OrcaLoadReport;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import io.grpc.Attributes;
import io.grpc.CallOptions;
import io.grpc.ClientStreamTracer;
import io.grpc.ConnectivityState;
import io.grpc.EquivalentAddressGroup;
import io.grpc.InsecureChannelCredentials;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancer.CreateSubchannelArgs;
import io.grpc.LoadBalancer.Helper;
import io.grpc.LoadBalancer.PickDetailsConsumer;
import io.grpc.LoadBalancer.PickResult;
import io.grpc.LoadBalancer.PickSubchannelArgs;
import io.grpc.LoadBalancer.ResolvedAddresses;
Expand All @@ -45,8 +48,10 @@
import io.grpc.SynchronizationContext;
import io.grpc.internal.FakeClock;
import io.grpc.internal.ObjectPool;
import io.grpc.internal.PickSubchannelArgsImpl;
import io.grpc.internal.ServiceConfigUtil.PolicySelection;
import io.grpc.protobuf.ProtoUtils;
import io.grpc.testing.TestMethodDescriptors;
import io.grpc.xds.ClusterImplLoadBalancerProvider.ClusterImplConfig;
import io.grpc.xds.Endpoints.DropOverload;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
Expand Down Expand Up @@ -141,6 +146,9 @@ public AtomicLong getOrCreate(String cluster, @Nullable String edsServiceName) {
}
};
private final Helper helper = new FakeLbHelper();
private PickSubchannelArgs pickSubchannelArgs = new PickSubchannelArgsImpl(
TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT,
new PickDetailsConsumer() {});
@Mock
private ThreadSafeRandom mockRandom;
private int xdsClientRefs;
Expand Down Expand Up @@ -218,7 +226,7 @@ public void handleResolvedAddresses_childPolicyChanges() {
public void nameResolutionError_beforeChildPolicyInstantiated_returnErrorPickerToUpstream() {
loadBalancer.handleNameResolutionError(Status.UNIMPLEMENTED.withDescription("not found"));
assertThat(currentState).isEqualTo(ConnectivityState.TRANSIENT_FAILURE);
PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class));
PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs);
assertThat(result.getStatus().isOk()).isFalse();
assertThat(result.getStatus().getCode()).isEqualTo(Code.UNIMPLEMENTED);
assertThat(result.getStatus().getDescription()).isEqualTo("not found");
Expand All @@ -243,6 +251,32 @@ public void nameResolutionError_afterChildPolicyInstantiated_propagateToDownstre
.isEqualTo("cannot reach server");
}

@Test
public void pick_addsLocalityLabel() {
LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider();
WeightedTargetConfig weightedTargetConfig =
buildWeightedTargetConfig(ImmutableMap.of(locality, 10));
ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO,
null, Collections.<DropOverload>emptyList(),
new PolicySelection(weightedTargetProvider, weightedTargetConfig), null);
EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality);
deliverAddressesAndConfig(Collections.singletonList(endpoint), config);
FakeLoadBalancer leafBalancer = Iterables.getOnlyElement(downstreamBalancers);
Subchannel subchannel = leafBalancer.helper.createSubchannel(
CreateSubchannelArgs.newBuilder().setAddresses(leafBalancer.addresses).build());
leafBalancer.deliverSubchannelState(subchannel, ConnectivityState.READY);
assertThat(currentState).isEqualTo(ConnectivityState.READY);

PickDetailsConsumer detailsConsumer = mock(PickDetailsConsumer.class);
pickSubchannelArgs = new PickSubchannelArgsImpl(
TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT, detailsConsumer);
PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs);
assertThat(result.getStatus().isOk()).isTrue();
// The value will be determined by the parent policy, so can be different than the value used in
// makeAddress() for the test.
verify(detailsConsumer).addOptionalLabel("grpc.lb.locality", locality.toString());
}

@Test
public void recordLoadStats() {
LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider();
Expand All @@ -258,7 +292,7 @@ public void recordLoadStats() {
CreateSubchannelArgs.newBuilder().setAddresses(leafBalancer.addresses).build());
leafBalancer.deliverSubchannelState(subchannel, ConnectivityState.READY);
assertThat(currentState).isEqualTo(ConnectivityState.READY);
PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class));
PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs);
assertThat(result.getStatus().isOk()).isTrue();
ClientStreamTracer streamTracer1 = result.getStreamTracerFactory().newClientStreamTracer(
ClientStreamTracer.StreamInfo.newBuilder().build(), new Metadata()); // first RPC call
Expand Down Expand Up @@ -347,7 +381,7 @@ public void dropRpcsWithRespectToLbConfigDropCategories() {
CreateSubchannelArgs.newBuilder().setAddresses(leafBalancer.addresses).build());
leafBalancer.deliverSubchannelState(subchannel, ConnectivityState.READY);
assertThat(currentState).isEqualTo(ConnectivityState.READY);
PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class));
PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs);
assertThat(result.getStatus().isOk()).isFalse();
assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE);
assertThat(result.getStatus().getDescription()).isEqualTo("Dropped: throttle");
Expand All @@ -373,7 +407,7 @@ public void dropRpcsWithRespectToLbConfigDropCategories() {
.build())
.setLoadBalancingPolicyConfig(config)
.build());
result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class));
result = currentPicker.pickSubchannel(pickSubchannelArgs);
assertThat(result.getStatus().isOk()).isFalse();
assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE);
assertThat(result.getStatus().getDescription()).isEqualTo("Dropped: lb");
Expand All @@ -386,7 +420,7 @@ public void dropRpcsWithRespectToLbConfigDropCategories() {
.isEqualTo(1L);
assertThat(clusterStats.totalDroppedRequests()).isEqualTo(1L);

result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class));
result = currentPicker.pickSubchannel(pickSubchannelArgs);
assertThat(result.getStatus().isOk()).isTrue();
}

Expand Down Expand Up @@ -423,7 +457,7 @@ private void subtest_maxConcurrentRequests_appliedByLbConfig(boolean enableCircu
leafBalancer.deliverSubchannelState(subchannel, ConnectivityState.READY);
assertThat(currentState).isEqualTo(ConnectivityState.READY);
for (int i = 0; i < maxConcurrentRequests; i++) {
PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class));
PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs);
assertThat(result.getStatus().isOk()).isTrue();
ClientStreamTracer.Factory streamTracerFactory = result.getStreamTracerFactory();
streamTracerFactory.newClientStreamTracer(
Expand All @@ -434,7 +468,7 @@ private void subtest_maxConcurrentRequests_appliedByLbConfig(boolean enableCircu
assertThat(clusterStats.clusterServiceName()).isEqualTo(EDS_SERVICE_NAME);
assertThat(clusterStats.totalDroppedRequests()).isEqualTo(0L);

PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class));
PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs);
clusterStats = Iterables.getOnlyElement(loadStatsManager.getClusterStatsReports(CLUSTER));
assertThat(clusterStats.clusterServiceName()).isEqualTo(EDS_SERVICE_NAME);
if (enableCircuitBreaking) {
Expand All @@ -455,15 +489,15 @@ private void subtest_maxConcurrentRequests_appliedByLbConfig(boolean enableCircu
new PolicySelection(weightedTargetProvider, weightedTargetConfig), null);
deliverAddressesAndConfig(Collections.singletonList(endpoint), config);

result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class));
result = currentPicker.pickSubchannel(pickSubchannelArgs);
assertThat(result.getStatus().isOk()).isTrue();
result.getStreamTracerFactory().newClientStreamTracer(
ClientStreamTracer.StreamInfo.newBuilder().build(), new Metadata()); // 101th request
clusterStats = Iterables.getOnlyElement(loadStatsManager.getClusterStatsReports(CLUSTER));
assertThat(clusterStats.clusterServiceName()).isEqualTo(EDS_SERVICE_NAME);
assertThat(clusterStats.totalDroppedRequests()).isEqualTo(0L);

result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class)); // 102th request
result = currentPicker.pickSubchannel(pickSubchannelArgs); // 102th request
clusterStats = Iterables.getOnlyElement(loadStatsManager.getClusterStatsReports(CLUSTER));
assertThat(clusterStats.clusterServiceName()).isEqualTo(EDS_SERVICE_NAME);
if (enableCircuitBreaking) {
Expand Down Expand Up @@ -511,7 +545,7 @@ private void subtest_maxConcurrentRequests_appliedWithDefaultValue(
leafBalancer.deliverSubchannelState(subchannel, ConnectivityState.READY);
assertThat(currentState).isEqualTo(ConnectivityState.READY);
for (int i = 0; i < ClusterImplLoadBalancer.DEFAULT_PER_CLUSTER_MAX_CONCURRENT_REQUESTS; i++) {
PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class));
PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs);
assertThat(result.getStatus().isOk()).isTrue();
ClientStreamTracer.Factory streamTracerFactory = result.getStreamTracerFactory();
streamTracerFactory.newClientStreamTracer(
Expand All @@ -522,7 +556,7 @@ private void subtest_maxConcurrentRequests_appliedWithDefaultValue(
assertThat(clusterStats.clusterServiceName()).isEqualTo(EDS_SERVICE_NAME);
assertThat(clusterStats.totalDroppedRequests()).isEqualTo(0L);

PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class));
PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs);
clusterStats = Iterables.getOnlyElement(loadStatsManager.getClusterStatsReports(CLUSTER));
assertThat(clusterStats.clusterServiceName()).isEqualTo(EDS_SERVICE_NAME);
if (enableCircuitBreaking) {
Expand Down Expand Up @@ -697,7 +731,11 @@ public String toString() {
}

EquivalentAddressGroup eag = new EquivalentAddressGroup(new FakeSocketAddress(name),
Attributes.newBuilder().set(InternalXdsAttributes.ATTR_LOCALITY, locality).build());
Attributes.newBuilder()
.set(InternalXdsAttributes.ATTR_LOCALITY, locality)
// Unique but arbitrary string
.set(InternalXdsAttributes.ATTR_LOCALITY_NAME, locality.toString())
.build());
return AddressFilter.setPathFilter(eag, Collections.singletonList(locality.toString()));
}

Expand Down

0 comments on commit 077dcbf

Please sign in to comment.