Skip to content

Commit

Permalink
Merge pull request #447 from yidongnan/metrics/cancellations
Browse files Browse the repository at this point in the history
Use ServerCall.Listener#onComplete for metric collection instead of ServerCall#close
  • Loading branch information
yidongnan committed Nov 12, 2020
2 parents aa286ab + 5e23eed commit 28d304d
Show file tree
Hide file tree
Showing 13 changed files with 370 additions and 65 deletions.
Expand Up @@ -17,15 +17,13 @@

package net.devh.boot.grpc.client.metric;

import java.util.function.Function;
import java.util.function.Consumer;

import io.grpc.ClientCall;
import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
import io.grpc.Metadata;
import io.grpc.Status.Code;
import io.grpc.Status;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Timer;

/**
* A simple forwarding client call that collects metrics for micrometer.
Expand All @@ -36,35 +34,37 @@
*/
class MetricCollectingClientCall<Q, A> extends SimpleForwardingClientCall<Q, A> {

private final MeterRegistry registry;
private final Counter requestCounter;
private final Counter responseCounter;
private final Function<Code, Timer> timerFunction;
private final Consumer<Status.Code> processingDurationTiming;

/**
* Creates a new delegating ClientCall that will wrap the given client call to collect metrics.
*
* @param delegate The original call to wrap.
* @param registry The registry to save the metrics to.
* @param requestCounter The counter for outgoing requests.
* @param responseCounter The counter for incoming responses.
* @param timerFunction A function that will return a timer for a given status code.
* @param processingDurationTiming The consumer used to time the processing duration along with a response status.
*/
public MetricCollectingClientCall(final ClientCall<Q, A> delegate, final MeterRegistry registry,
final Counter requestCounter, final Counter responseCounter,
final Function<Code, Timer> timerFunction) {
public MetricCollectingClientCall(
final ClientCall<Q, A> delegate,
final Counter requestCounter,
final Counter responseCounter,
final Consumer<Status.Code> processingDurationTiming) {

super(delegate);
this.registry = registry;
this.requestCounter = requestCounter;
this.responseCounter = responseCounter;
this.timerFunction = timerFunction;
this.processingDurationTiming = processingDurationTiming;
}

@Override
public void start(final ClientCall.Listener<A> responseListener, final Metadata metadata) {
super.start(
new MetricCollectingClientCallListener<>(responseListener, this.registry, this.responseCounter,
this.timerFunction),
new MetricCollectingClientCallListener<>(
responseListener,
this.responseCounter,
this.processingDurationTiming),
metadata);
}

Expand Down
Expand Up @@ -17,16 +17,13 @@

package net.devh.boot.grpc.client.metric;

import java.util.function.Function;
import java.util.function.Consumer;

import io.grpc.ClientCall;
import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Timer;

/**
* A simple forwarding client call listener that collects metrics for micrometer.
Expand All @@ -36,32 +33,29 @@
*/
class MetricCollectingClientCallListener<A> extends SimpleForwardingClientCallListener<A> {

private final Timer.Sample timerSample;
private final Counter responseCounter;
private final Function<Code, Timer> timerFunction;
private final Consumer<Status.Code> processingDurationTiming;

/**
* Creates a new delegating ClientCallListener that will wrap the given client call listener to collect metrics.
*
* @param delegate The original call to wrap.
* @param registry The registry to save the metrics to.
* @param responseCounter The counter for incoming responses.
* @param timerFunction A function that will return a timer for a given status code.
* @param processingDurationTiming The consumer used to time the processing duration along with a response status.
*/
public MetricCollectingClientCallListener(
final ClientCall.Listener<A> delegate,
final MeterRegistry registry,
final Counter responseCounter,
final Function<Code, Timer> timerFunction) {
final Consumer<Status.Code> processingDurationTiming) {

super(delegate);
this.responseCounter = responseCounter;
this.timerFunction = timerFunction;
this.timerSample = Timer.start(registry);
this.processingDurationTiming = processingDurationTiming;
}

@Override
public void onClose(final Status status, final Metadata metadata) {
this.timerSample.stop(this.timerFunction.apply(status.getCode()));
this.processingDurationTiming.accept(status.getCode());
super.onClose(status, metadata);
}

Expand Down
Expand Up @@ -23,6 +23,7 @@
import static net.devh.boot.grpc.common.metric.MetricUtils.prepareCounterFor;
import static net.devh.boot.grpc.common.metric.MetricUtils.prepareTimerFor;

import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.UnaryOperator;

Expand Down Expand Up @@ -104,15 +105,19 @@ protected Function<Code, Timer> newTimerFunction(final MethodDescriptor<?, ?> me
}

@Override
public <Q, A> ClientCall<Q, A> interceptCall(final MethodDescriptor<Q, A> methodDescriptor,
final CallOptions callOptions, final Channel channel) {
public <Q, A> ClientCall<Q, A> interceptCall(
final MethodDescriptor<Q, A> methodDescriptor,
final CallOptions callOptions,
final Channel channel) {

final MetricSet metrics = metricsFor(methodDescriptor);
final Consumer<Code> processingDurationTiming = metrics.newProcessingDurationTiming(this.registry);

return new MetricCollectingClientCall<>(
channel.newCall(methodDescriptor, callOptions),
this.registry,
metrics.getRequestCounter(),
metrics.getResponseCounter(),
metrics.getTimerFunction());
processingDurationTiming);
}

}
Expand Up @@ -22,16 +22,19 @@
import java.util.EnumMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.function.UnaryOperator;

import io.grpc.MethodDescriptor;
import io.grpc.ServiceDescriptor;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Timer;
import io.micrometer.core.instrument.Timer.Sample;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;

Expand Down Expand Up @@ -185,13 +188,29 @@ protected static class MetricSet {
* @param responseCounter The response counter to use.
* @param timerFunction The timer function to use.
*/
public MetricSet(final Counter requestCounter, final Counter responseCounter,
public MetricSet(
final Counter requestCounter,
final Counter responseCounter,
final Function<Code, Timer> timerFunction) {

this.requestCounter = requestCounter;
this.responseCounter = responseCounter;
this.timerFunction = timerFunction;
}

/**
* Uses the given registry to create a {@link Sample Timer.Sample} that will be reported if the returned
* consumer is invoked.
*
* @param registry The registry used to create the sample.
* @return The newly created consumer that will report the processing duration since calling this method and
* invoking the returned consumer along with the status code.
*/
public Consumer<Status.Code> newProcessingDurationTiming(final MeterRegistry registry) {
final Timer.Sample timerSample = Timer.start(registry);
return code -> timerSample.stop(this.timerFunction.apply(code));
}

}

}
Expand Up @@ -17,16 +17,12 @@

package net.devh.boot.grpc.server.metric;

import java.util.function.Function;

import io.grpc.ForwardingServerCall.SimpleForwardingServerCall;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Timer;

/**
* A simple forwarding server call that collects metrics for micrometer.
Expand All @@ -38,29 +34,29 @@
class MetricCollectingServerCall<Q, A> extends SimpleForwardingServerCall<Q, A> {

private final Counter responseCounter;
private final Function<Code, Timer> timerFunction;
private final Timer.Sample timerSample;
private Code responseCode = Code.UNKNOWN;

/**
* Creates a new delegating ServerCall that will wrap the given server call to collect metrics.
*
* @param delegate The original call to wrap.
* @param registry The registry to save the metrics to.
* @param responseCounter The counter for incoming responses.
* @param timerFunction A function that will return a timer for a given status code.
*/
public MetricCollectingServerCall(final ServerCall<Q, A> delegate, final MeterRegistry registry,
final Counter responseCounter,
final Function<Code, Timer> timerFunction) {
public MetricCollectingServerCall(
final ServerCall<Q, A> delegate,
final Counter responseCounter) {

super(delegate);
this.responseCounter = responseCounter;
this.timerFunction = timerFunction;
this.timerSample = Timer.start(registry);
}

public Code getResponseCode() {
return this.responseCode;
}

@Override
public void close(final Status status, final Metadata responseHeaders) {
this.timerSample.stop(this.timerFunction.apply(status.getCode()));
this.responseCode = status.getCode();
super.close(status, responseHeaders);
}

Expand Down
Expand Up @@ -17,8 +17,12 @@

package net.devh.boot.grpc.server.metric;

import java.util.function.Consumer;
import java.util.function.Supplier;

import io.grpc.ForwardingServerCallListener.SimpleForwardingServerCallListener;
import io.grpc.ServerCall;
import io.grpc.ServerCall.Listener;
import io.grpc.Status;
import io.micrometer.core.instrument.Counter;

/**
Expand All @@ -30,16 +34,28 @@
class MetricCollectingServerCallListener<Q> extends SimpleForwardingServerCallListener<Q> {

private final Counter requestCounter;
private final Supplier<Status.Code> responseCodeSupplier;
private final Consumer<Status.Code> responseStatusTiming;

/**
* Creates a new delegating ServerCallListener that will wrap the given server call listener to collect metrics.
*
* @param delegate The original listener to wrap.
* @param requestCounter The counter for incoming requests.
* @param responseCodeSupplier The supplier of the response code.
* @param responseStatusTiming The consumer used to time the processing duration along with a response status.
*/
public MetricCollectingServerCallListener(final ServerCall.Listener<Q> delegate, final Counter requestCounter) {

public MetricCollectingServerCallListener(
final Listener<Q> delegate,
final Counter requestCounter,
final Supplier<Status.Code> responseCodeSupplier,
final Consumer<Status.Code> responseStatusTiming) {

super(delegate);
this.requestCounter = requestCounter;
this.responseCodeSupplier = responseCodeSupplier;
this.responseStatusTiming = responseStatusTiming;
}

@Override
Expand All @@ -48,4 +64,20 @@ public void onMessage(final Q requestMessage) {
super.onMessage(requestMessage);
}

@Override
public void onComplete() {
report(this.responseCodeSupplier.get());
super.onComplete();
}

@Override
public void onCancel() {
report(Status.Code.CANCELLED);
super.onCancel();
}

private void report(final Status.Code code) {
this.responseStatusTiming.accept(code);
}

}
Expand Up @@ -23,6 +23,7 @@
import static net.devh.boot.grpc.common.metric.MetricUtils.prepareCounterFor;
import static net.devh.boot.grpc.common.metric.MetricUtils.prepareTimerFor;

import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.UnaryOperator;

Expand All @@ -36,6 +37,7 @@
import io.grpc.ServerInterceptor;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServiceDescriptor;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
Expand Down Expand Up @@ -131,11 +133,18 @@ public <Q, A> ServerCall.Listener<Q> interceptCall(
final ServerCall<Q, A> call,
final Metadata requestHeaders,
final ServerCallHandler<Q, A> next) {

final MetricSet metrics = metricsFor(call.getMethodDescriptor());
final ServerCall<Q, A> monitoringCall = new MetricCollectingServerCall<>(call, this.registry,
metrics.getResponseCounter(), metrics.getTimerFunction());
final Consumer<Status.Code> responseStatusTiming = metrics.newProcessingDurationTiming(this.registry);

final MetricCollectingServerCall<Q, A> monitoringCall =
new MetricCollectingServerCall<>(call, metrics.getResponseCounter());

return new MetricCollectingServerCallListener<>(
next.startCall(monitoringCall, requestHeaders), metrics.getRequestCounter());
next.startCall(monitoringCall, requestHeaders),
metrics.getRequestCounter(),
monitoringCall::getResponseCode,
responseStatusTiming);
}

}
Expand Up @@ -39,7 +39,7 @@
class MetricCollectingClientInterceptorTest {

@Test
public void testClientPreRegistration() {
void testClientPreRegistration() {
log.info("--- Starting tests with client pre-registration ---");
final MeterRegistry meterRegistry = new SimpleMeterRegistry();
assertEquals(0, meterRegistry.getMeters().size());
Expand All @@ -52,7 +52,7 @@ public void testClientPreRegistration() {
}

@Test
public void testClientCustomization() {
void testClientCustomization() {
log.info("--- Starting tests with client customization ---");
final MeterRegistry meterRegistry = new SimpleMeterRegistry();
assertEquals(0, meterRegistry.getMeters().size());
Expand Down

0 comments on commit 28d304d

Please sign in to comment.