diff --git a/src/main/java/me/dinowernli/grpc/prometheus/MonitoringServerCall.java b/src/main/java/me/dinowernli/grpc/prometheus/MonitoringServerCall.java index ed68d87..ec0de06 100644 --- a/src/main/java/me/dinowernli/grpc/prometheus/MonitoringServerCall.java +++ b/src/main/java/me/dinowernli/grpc/prometheus/MonitoringServerCall.java @@ -9,9 +9,9 @@ import io.grpc.ForwardingServerCall; import io.grpc.Metadata; import io.grpc.MethodDescriptor.MethodType; -import me.dinowernli.grpc.prometheus.MonitoringServerInterceptor.Configuration; import io.grpc.ServerCall; import io.grpc.Status; +import me.dinowernli.grpc.prometheus.MonitoringServerInterceptor.Configuration; /** * A {@link ForwardingServerCall} which update Prometheus metrics based on the server-side actions diff --git a/src/main/java/me/dinowernli/grpc/prometheus/MonitoringServerCallListener.java b/src/main/java/me/dinowernli/grpc/prometheus/MonitoringServerCallListener.java index 3a43a4d..20392d5 100644 --- a/src/main/java/me/dinowernli/grpc/prometheus/MonitoringServerCallListener.java +++ b/src/main/java/me/dinowernli/grpc/prometheus/MonitoringServerCallListener.java @@ -1,6 +1,7 @@ package me.dinowernli.grpc.prometheus; import io.grpc.ForwardingServerCallListener; +import io.grpc.MethodDescriptor.MethodType; import io.grpc.ServerCall; /** @@ -10,12 +11,14 @@ class MonitoringServerCallListener extends ForwardingServerCallListener { private final ServerCall.Listener delegate; + private final MethodType methodType; private final ServerMetrics serverMetrics; MonitoringServerCallListener( - ServerCall.Listener delegate, ServerMetrics serverMetrics) { + ServerCall.Listener delegate, ServerMetrics serverMetrics, MethodType methodType) { this.delegate = delegate; this.serverMetrics = serverMetrics; + this.methodType = methodType; } @Override @@ -25,7 +28,9 @@ protected ServerCall.Listener delegate() { @Override public void onMessage(R request) { - serverMetrics.recordMessageReceived(); + if (methodType == MethodType.CLIENT_STREAMING || methodType == MethodType.BIDI_STREAMING) { + serverMetrics.recordStreamMessageReceived(); + } super.onMessage(request); } } \ No newline at end of file diff --git a/src/main/java/me/dinowernli/grpc/prometheus/MonitoringServerInterceptor.java b/src/main/java/me/dinowernli/grpc/prometheus/MonitoringServerInterceptor.java index 3b9bed5..edb6f01 100644 --- a/src/main/java/me/dinowernli/grpc/prometheus/MonitoringServerInterceptor.java +++ b/src/main/java/me/dinowernli/grpc/prometheus/MonitoringServerInterceptor.java @@ -16,14 +16,18 @@ public class MonitoringServerInterceptor implements ServerInterceptor { private final Clock clock; private final Configuration configuration; + private final ServerMetrics.Factory serverMetricsFactory; public static MonitoringServerInterceptor create(Configuration configuration) { - return new MonitoringServerInterceptor(Clock.systemDefaultZone(), configuration); + return new MonitoringServerInterceptor( + Clock.systemDefaultZone(), configuration, new ServerMetrics.Factory(configuration)); } - private MonitoringServerInterceptor(Clock clock, Configuration configuration) { + private MonitoringServerInterceptor( + Clock clock, Configuration configuration, ServerMetrics.Factory serverMetricsFactory) { this.clock = clock; this.configuration = configuration; + this.serverMetricsFactory = serverMetricsFactory; } @Override @@ -32,13 +36,11 @@ public ServerCall.Listener interceptCall( ServerCall call, Metadata requestHeaders, ServerCallHandler next) { - // TODO(dino): If we cache the ServerMetrics instance, we can achieve an initial 0 value on - // registration and save some cycles here where we always create a new one per-request. - ServerMetrics metrics = ServerMetrics.create(method, configuration.getCollectorRegistry()); + ServerMetrics metrics = serverMetricsFactory.createMetricsForMethod(method); ServerCall monitoringCall = new MonitoringServerCall( call, clock, method.getType(), metrics, configuration); return new MonitoringServerCallListener( - next.startCall(method, monitoringCall, requestHeaders), metrics); + next.startCall(method, monitoringCall, requestHeaders), metrics, method.getType()); } /** @@ -61,7 +63,7 @@ public static Configuration cheapMetricsOnly() { */ public static Configuration allMetrics() { return new Configuration( - false /* isIncludeLatencyHistograms */, Optional.empty() /* collectorRegistry */); + true /* isIncludeLatencyHistograms */, Optional.empty() /* collectorRegistry */); } /** @@ -78,8 +80,8 @@ public boolean isIncludeLatencyHistograms() { } /** Returns the {@link CollectorRegistry} used to record stats. */ - public Optional getCollectorRegistry() { - return collectorRegistry; + public CollectorRegistry getCollectorRegistry() { + return collectorRegistry.orElse(CollectorRegistry.defaultRegistry); } private Configuration( diff --git a/src/main/java/me/dinowernli/grpc/prometheus/ServerMetrics.java b/src/main/java/me/dinowernli/grpc/prometheus/ServerMetrics.java index b7e9e17..fe6d97c 100644 --- a/src/main/java/me/dinowernli/grpc/prometheus/ServerMetrics.java +++ b/src/main/java/me/dinowernli/grpc/prometheus/ServerMetrics.java @@ -11,6 +11,7 @@ import io.prometheus.client.Counter; import io.prometheus.client.Histogram; import io.prometheus.client.SimpleCollector; +import me.dinowernli.grpc.prometheus.MonitoringServerInterceptor.Configuration; /** * Prometheus metric definitions used for server-side monitoring of grpc services. @@ -47,53 +48,42 @@ class ServerMetrics { .subsystem("server") .name("msg_received_total") .labelNames("grpc_type", "grpc_service", "grpc_method") - .help("Total number of messages received from the client."); + .help("Total number of stream messages received from the client."); private static final Counter.Builder serverStreamMessagesSentBuilder = Counter.build() .namespace("grpc") .subsystem("server") .name("msg_sent_total") .labelNames("grpc_type", "grpc_service", "grpc_method") - .help("Total number of gRPC stream messages sent by the server."); + .help("Total number of stream messages sent by the server."); private final Counter serverStarted; private final Counter serverHandled; - private final Histogram serverHandledLatencySeconds; private final Counter serverStreamMessagesReceived; private final Counter serverStreamMessagesSent; + private final Optional serverHandledLatencySeconds; private final String methodTypeLabel; private final String serviceNameLabel; private final String methodNameLabel; - /** - * Creates an instance of {@link ServerMetrics} for the supplied method. If the - * {@link CollectorRegistry} is empty, the default global registry is used. - */ - static ServerMetrics create( - MethodDescriptor method, Optional collectorRegistry) { - CollectorRegistry registry = collectorRegistry.orElse(CollectorRegistry.defaultRegistry); - String serviceName = MethodDescriptor.extractFullServiceName(method.getFullMethodName()); - - // Full method names are of the form: "full.serviceName/MethodName". We extract the last part. - String methodName = method.getFullMethodName().substring(serviceName.length() + 1); - return new ServerMetrics(method.getType().toString(), serviceName, methodName, registry); - } - private ServerMetrics( String methodTypeLabel, String serviceNameLabel, String methodNameLabel, - CollectorRegistry registry) { + Counter serverStarted, + Counter serverHandled, + Counter serverStreamMessagesReceived, + Counter serverStreamMessagesSent, + Optional serverHandledLatencySeconds) { this.methodNameLabel = methodNameLabel; this.methodTypeLabel = methodTypeLabel; this.serviceNameLabel = serviceNameLabel; - - this.serverStarted = serverStartedBuilder.register(registry); - this.serverHandled = serverHandledBuilder.register(registry); - this.serverHandledLatencySeconds = serverHandledLatencySecondsBuilder.register(registry); - this.serverStreamMessagesReceived = serverStreamMessagesReceivedBuilder.register(registry); - this.serverStreamMessagesSent = serverStreamMessagesSentBuilder.register(registry); + this.serverStarted = serverStarted; + this.serverHandled = serverHandled; + this.serverStreamMessagesReceived = serverStreamMessagesReceived; + this.serverStreamMessagesSent = serverStreamMessagesSent; + this.serverHandledLatencySeconds = serverHandledLatencySeconds; } public void recordCallStarted() { @@ -108,12 +98,62 @@ public void recordStreamMessageSent() { addLabels(serverStreamMessagesSent).inc(); } + public void recordStreamMessageReceived() { + addLabels(serverStreamMessagesReceived).inc(); + } + + /** + * Only has any effect if monitoring is configured to include latency histograms. Otherwise, this + * does nothing. + */ public void recordLatency(double latencySec) { - addLabels(serverHandledLatencySeconds).observe(latencySec); + if (!this.serverHandledLatencySeconds.isPresent()) { + return; + } + addLabels(this.serverHandledLatencySeconds.get()).observe(latencySec); } - public void recordMessageReceived() { - addLabels(serverStreamMessagesReceived).inc(); + /** + * Knows how to produce {@link ServerMetrics} instances for individual methods. + */ + static class Factory { + private final Counter serverStarted; + private final Counter serverHandled; + private final Counter serverStreamMessagesReceived; + private final Counter serverStreamMessagesSent; + private final Optional serverHandledLatencySeconds; + + Factory(Configuration configuration) { + CollectorRegistry registry = configuration.getCollectorRegistry(); + this.serverStarted = serverStartedBuilder.register(registry); + this.serverHandled = serverHandledBuilder.register(registry); + this.serverStreamMessagesReceived = serverStreamMessagesReceivedBuilder.register(registry); + this.serverStreamMessagesSent = serverStreamMessagesSentBuilder.register(registry); + + if (configuration.isIncludeLatencyHistograms()) { + this.serverHandledLatencySeconds = + Optional.of(serverHandledLatencySecondsBuilder.register(registry)); + } else { + this.serverHandledLatencySeconds = Optional.empty(); + } + } + + /** Creates a {@link ServerMetrics} for the supplied method. */ + public ServerMetrics createMetricsForMethod(MethodDescriptor method) { + String serviceName = MethodDescriptor.extractFullServiceName(method.getFullMethodName()); + + // Full method names are of the form: "full.serviceName/MethodName". We extract the last part. + String methodName = method.getFullMethodName().substring(serviceName.length() + 1); + return new ServerMetrics( + method.getType().toString(), + serviceName, + methodName, + serverStarted, + serverHandled, + serverStreamMessagesReceived, + serverStreamMessagesSent, + serverHandledLatencySeconds); + } } private T addLabels(SimpleCollector collector, String... labels) { diff --git a/src/test/java/me/dinowernli/grpc/prometheus/integration/MonitoringInterceptorIntegrationTest.java b/src/test/java/me/dinowernli/grpc/prometheus/integration/MonitoringInterceptorIntegrationTest.java index 0ff105c..ce7361c 100644 --- a/src/test/java/me/dinowernli/grpc/prometheus/integration/MonitoringInterceptorIntegrationTest.java +++ b/src/test/java/me/dinowernli/grpc/prometheus/integration/MonitoringInterceptorIntegrationTest.java @@ -6,6 +6,7 @@ import java.io.IOException; import java.util.Enumeration; +import java.util.Optional; import org.junit.After; import org.junit.Before; @@ -28,10 +29,10 @@ import io.grpc.testing.StreamRecorder; import io.grpc.testing.TestUtils; import io.prometheus.client.Collector.MetricFamilySamples; +import io.prometheus.client.CollectorRegistry; import me.dinowernli.grpc.prometheus.MonitoringServerInterceptor; import me.dinowernli.grpc.prometheus.MonitoringServerInterceptor.Configuration; import me.dinowernli.grpc.prometheus.testing.HelloServiceImpl; -import io.prometheus.client.CollectorRegistry; /** * Integrations tests which make sure that if a service is started with a @@ -50,6 +51,9 @@ public class MonitoringInterceptorIntegrationTest { .setRecipient(RECIPIENT) .build(); + private static final Configuration CHEAP_METRICS = Configuration.cheapMetricsOnly(); + private static final Configuration ALL_METRICS = Configuration.allMetrics(); + private CollectorRegistry collectorRegistry; private Server grpcServer; private int grpcPort; @@ -57,7 +61,6 @@ public class MonitoringInterceptorIntegrationTest { @Before public void setUp() { collectorRegistry = new CollectorRegistry(); - startGrpcServer(); } @After @@ -67,9 +70,14 @@ public void tearDown() throws Exception { @Test public void unaryRpcMetrics() throws Throwable { + startGrpcServer(CHEAP_METRICS); createGrpcBlockingStub().sayHello(REQUEST); - MetricFamilySamples handled = findRecordedMetric("grpc_server_handled_total"); + assertThat(findRecordedMetricOrThrow("grpc_server_started_total").samples).hasSize(1); + assertThat(findRecordedMetricOrThrow("grpc_server_msg_received_total").samples).isEmpty(); + assertThat(findRecordedMetricOrThrow("grpc_server_msg_sent_total").samples).isEmpty(); + + MetricFamilySamples handled = findRecordedMetricOrThrow("grpc_server_handled_total"); assertThat(handled.samples).hasSize(1); assertThat(handled.samples.get(0).labelValues).containsExactly( "UNARY", SERVICE_NAME, UNARY_METHOD_NAME, "OK"); @@ -78,6 +86,7 @@ public void unaryRpcMetrics() throws Throwable { @Test public void clientStreamRpcMetrics() throws Throwable { + startGrpcServer(CHEAP_METRICS); StreamRecorder streamRecorder = StreamRecorder.create(); StreamObserver requestStream = createGrpcStub().sayHelloClientStream(streamRecorder); @@ -88,7 +97,11 @@ public void clientStreamRpcMetrics() throws Throwable { // Not a blocking stub, so we need to wait. streamRecorder.awaitCompletion(); - MetricFamilySamples handled = findRecordedMetric("grpc_server_handled_total"); + assertThat(findRecordedMetricOrThrow("grpc_server_started_total").samples).hasSize(1); + assertThat(findRecordedMetricOrThrow("grpc_server_msg_received_total").samples).hasSize(1); + assertThat(findRecordedMetricOrThrow("grpc_server_msg_sent_total").samples).isEmpty(); + + MetricFamilySamples handled = findRecordedMetricOrThrow("grpc_server_handled_total"); assertThat(handled.samples).hasSize(1); assertThat(handled.samples.get(0).labelValues).containsExactly( "CLIENT_STREAMING", SERVICE_NAME, CLIENT_STREAM_METHOD_NAME, "OK"); @@ -97,16 +110,21 @@ public void clientStreamRpcMetrics() throws Throwable { @Test public void serverStreamRpcMetrics() throws Throwable { + startGrpcServer(CHEAP_METRICS); ImmutableList responses = ImmutableList.copyOf(createGrpcBlockingStub().sayHelloServerStream(REQUEST)); - MetricFamilySamples handled = findRecordedMetric("grpc_server_handled_total"); + assertThat(findRecordedMetricOrThrow("grpc_server_started_total").samples).hasSize(1); + assertThat(findRecordedMetricOrThrow("grpc_server_msg_received_total").samples).isEmpty(); + assertThat(findRecordedMetricOrThrow("grpc_server_msg_sent_total").samples).hasSize(1); + + MetricFamilySamples handled = findRecordedMetricOrThrow("grpc_server_handled_total"); assertThat(handled.samples).hasSize(1); assertThat(handled.samples.get(0).labelValues).containsExactly( "SERVER_STREAMING", SERVICE_NAME, SERVER_STREAM_METHOD_NAME, "OK"); assertThat(handled.samples.get(0).value).isWithin(0).of(1); - MetricFamilySamples messagesSent = findRecordedMetric("grpc_server_msg_sent_total"); + MetricFamilySamples messagesSent = findRecordedMetricOrThrow("grpc_server_msg_sent_total"); assertThat(messagesSent.samples.get(0).labelValues).containsExactly( "SERVER_STREAMING", SERVICE_NAME, SERVER_STREAM_METHOD_NAME); assertThat(messagesSent.samples.get(0).value).isWithin(0).of(responses.size()); @@ -114,6 +132,7 @@ public void serverStreamRpcMetrics() throws Throwable { @Test public void bidiStreamRpcMetrics() throws Throwable { + startGrpcServer(CHEAP_METRICS); StreamRecorder streamRecorder = StreamRecorder.create(); StreamObserver requestStream = createGrpcStub().sayHelloBidiStream(streamRecorder); @@ -124,7 +143,11 @@ public void bidiStreamRpcMetrics() throws Throwable { // Not a blocking stub, so we need to wait. streamRecorder.awaitCompletion(); - MetricFamilySamples handled = findRecordedMetric("grpc_server_handled_total"); + assertThat(findRecordedMetricOrThrow("grpc_server_started_total").samples).hasSize(1); + assertThat(findRecordedMetricOrThrow("grpc_server_msg_received_total").samples).hasSize(1); + assertThat(findRecordedMetricOrThrow("grpc_server_msg_sent_total").samples).hasSize(1); + + MetricFamilySamples handled = findRecordedMetricOrThrow("grpc_server_handled_total"); assertThat(handled.samples).hasSize(1); assertThat(handled.samples.get(0).labelValues).containsExactly( "BIDI_STREAMING", SERVICE_NAME, BIDI_STREAM_METHOD_NAME, "OK"); @@ -132,15 +155,44 @@ public void bidiStreamRpcMetrics() throws Throwable { } @Test - public void doesNotAddHistogramsIfDisabled() throws Throwable { + public void noHistogramIfDisabled() throws Throwable { + startGrpcServer(CHEAP_METRICS); + createGrpcBlockingStub().sayHello(REQUEST); + assertThat(findRecordedMetric("grpc_server_handled_latency_seconds").isPresent()).isFalse(); + } + + @Test + public void addsHistogramIfEnabled() throws Throwable { + startGrpcServer(ALL_METRICS); + createGrpcBlockingStub().sayHello(REQUEST); + + MetricFamilySamples latency = findRecordedMetricOrThrow("grpc_server_handled_latency_seconds"); + assertThat(latency.samples.size()).isGreaterThan(0); + } + + @Test + public void recordsMultipleCalls() throws Throwable { + startGrpcServer(CHEAP_METRICS); + + createGrpcBlockingStub().sayHello(REQUEST); + createGrpcBlockingStub().sayHello(REQUEST); createGrpcBlockingStub().sayHello(REQUEST); - // TODO(dino): Add plumbing for disabling histograms in the test. + StreamRecorder streamRecorder = StreamRecorder.create(); + StreamObserver requestStream = + createGrpcStub().sayHelloBidiStream(streamRecorder); + requestStream.onNext(REQUEST); + requestStream.onNext(REQUEST); + requestStream.onCompleted(); + streamRecorder.awaitCompletion(); + + assertThat(findRecordedMetricOrThrow("grpc_server_started_total").samples).hasSize(2); + assertThat(findRecordedMetricOrThrow("grpc_server_handled_total").samples).hasSize(2); } - private void startGrpcServer() { + private void startGrpcServer(Configuration monitoringConfig) { MonitoringServerInterceptor interceptor = MonitoringServerInterceptor.create( - Configuration.cheapMetricsOnly().withCollectorRegistry(collectorRegistry)); + monitoringConfig.withCollectorRegistry(collectorRegistry)); grpcPort = TestUtils.pickUnusedPort(); grpcServer = ServerBuilder.forPort(grpcPort) .addService(ServerInterceptors.intercept( @@ -153,15 +205,23 @@ private void startGrpcServer() { } } - private MetricFamilySamples findRecordedMetric(String name) { + private Optional findRecordedMetric(String name) { Enumeration samples = collectorRegistry.metricFamilySamples(); while (samples.hasMoreElements()) { MetricFamilySamples sample = samples.nextElement(); if (sample.name.equals(name)) { - return sample; + return Optional.of(sample); } } - throw new IllegalArgumentException("Could not find metric with name: " + name); + return Optional.empty(); + } + + private MetricFamilySamples findRecordedMetricOrThrow(String name) { + Optional result = findRecordedMetric(name); + if (!result.isPresent()){ + throw new IllegalArgumentException("Could not find metric with name: " + name); + } + return result.get(); } private HelloServiceBlockingStub createGrpcBlockingStub() {