diff --git a/core/src/main/java/tech/ydb/core/grpc/GrpcTransportBuilder.java b/core/src/main/java/tech/ydb/core/grpc/GrpcTransportBuilder.java index 1d1af684a..1e270a684 100644 --- a/core/src/main/java/tech/ydb/core/grpc/GrpcTransportBuilder.java +++ b/core/src/main/java/tech/ydb/core/grpc/GrpcTransportBuilder.java @@ -71,6 +71,8 @@ public enum InitMode { private byte[] cert = null; private boolean useTLS = false; + private String applicationName = null; + private String clientProcessId = null; private ManagedChannelFactory.Builder channelFactoryBuilder = null; private final List>> channelInitializers = new ArrayList<>(); private Supplier schedulerFactory = YdbSchedulerFactory::createScheduler; @@ -127,6 +129,14 @@ public String getVersionString() { .orElse(Version.UNKNOWN_VERSION); } + public String getApplicationName() { + return applicationName; + } + + public String getClientProcessId() { + return clientProcessId; + } + public Supplier getSchedulerFactory() { return schedulerFactory; } @@ -255,6 +265,16 @@ public GrpcTransportBuilder withSecureConnection() { return this; } + public GrpcTransportBuilder withApplicationName(String applicationName) { + this.applicationName = applicationName; + return this; + } + + public GrpcTransportBuilder withClientProcessId(String clientProcessId) { + this.clientProcessId = clientProcessId; + return this; + } + public GrpcTransportBuilder withBalancingSettings(BalancingSettings balancingSettings) { this.balancingSettings = balancingSettings; return this; diff --git a/core/src/main/java/tech/ydb/core/grpc/YdbHeaders.java b/core/src/main/java/tech/ydb/core/grpc/YdbHeaders.java index 8e78a63f2..4208aaab4 100644 --- a/core/src/main/java/tech/ydb/core/grpc/YdbHeaders.java +++ b/core/src/main/java/tech/ydb/core/grpc/YdbHeaders.java @@ -1,6 +1,8 @@ package tech.ydb.core.grpc; +import io.grpc.ClientInterceptor; import io.grpc.Metadata; +import io.grpc.stub.MetadataUtils; /** @@ -22,5 +24,26 @@ public class YdbHeaders { public static final Metadata.Key YDB_SERVER_HINTS = Metadata.Key.of("x-ydb-server-hints", Metadata.ASCII_STRING_MARSHALLER); + public static final Metadata.Key APPLICATION_NAME = + Metadata.Key.of("x-ydb-application-name", Metadata.ASCII_STRING_MARSHALLER); + + public static final Metadata.Key CLIENT_PROCESS_ID = + Metadata.Key.of("x-ydb-client-pid", Metadata.ASCII_STRING_MARSHALLER); + private YdbHeaders() { } + + public static ClientInterceptor createMetadataInterceptor(GrpcTransportBuilder builder) { + Metadata extraHeaders = new Metadata(); + extraHeaders.put(YdbHeaders.DATABASE, builder.getDatabase()); + extraHeaders.put(YdbHeaders.BUILD_INFO, builder.getVersionString()); + String appName = builder.getApplicationName(); + if (appName != null) { + extraHeaders.put(YdbHeaders.APPLICATION_NAME, appName); + } + String clientPid = builder.getClientProcessId(); + if (clientPid != null) { + extraHeaders.put(YdbHeaders.CLIENT_PROCESS_ID, clientPid); + } + return MetadataUtils.newAttachHeadersInterceptor(extraHeaders); + } } diff --git a/core/src/main/java/tech/ydb/core/impl/pool/NettyChannelFactory.java b/core/src/main/java/tech/ydb/core/impl/pool/NettyChannelFactory.java index 6f2cd9ca3..2b6927896 100644 --- a/core/src/main/java/tech/ydb/core/impl/pool/NettyChannelFactory.java +++ b/core/src/main/java/tech/ydb/core/impl/pool/NettyChannelFactory.java @@ -10,12 +10,10 @@ import io.grpc.ClientInterceptor; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; -import io.grpc.Metadata; import io.grpc.internal.DnsNameResolverProvider; import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.NegotiationType; import io.grpc.netty.NettyChannelBuilder; -import io.grpc.stub.MetadataUtils; import io.netty.buffer.ByteBufAllocator; import io.netty.channel.ChannelOption; import io.netty.handler.ssl.SslContext; @@ -33,8 +31,8 @@ public class NettyChannelFactory implements ManagedChannelFactory { static final int INBOUND_MESSAGE_SIZE = 64 << 20; // 64 MiB static final String DEFAULT_BALANCER_POLICY = "round_robin"; - private final String database; - private final String version; + private final ClientInterceptor metadata; + private final boolean useTLS; private final byte[] cert; private final boolean retryEnabled; @@ -44,8 +42,7 @@ public class NettyChannelFactory implements ManagedChannelFactory { private final List>> initializers; private NettyChannelFactory(GrpcTransportBuilder builder) { - this.database = builder.getDatabase(); - this.version = builder.getVersionString(); + this.metadata = YdbHeaders.createMetadataInterceptor(builder); this.useTLS = builder.getUseTls(); this.cert = builder.getCert(); this.retryEnabled = builder.isEnableRetry(); @@ -81,7 +78,7 @@ public ManagedChannel newManagedChannel(String host, int port, String sslHostOve .maxInboundMessageSize(INBOUND_MESSAGE_SIZE) .withOption(ChannelOption.ALLOCATOR, ByteBufAllocator.DEFAULT) .withOption(ChannelOption.TCP_NODELAY, true) - .intercept(metadataInterceptor()); + .intercept(metadata); if (!useDefaultGrpcResolver) { // force usage of dns resolver and round_robin balancer @@ -114,13 +111,6 @@ protected void configure(NettyChannelBuilder channelBuilder) { } - private ClientInterceptor metadataInterceptor() { - Metadata extraHeaders = new Metadata(); - extraHeaders.put(YdbHeaders.DATABASE, database); - extraHeaders.put(YdbHeaders.BUILD_INFO, version); - return MetadataUtils.newAttachHeadersInterceptor(extraHeaders); - } - private SslContext createSslContext() { try { SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient(); diff --git a/core/src/main/java/tech/ydb/core/impl/pool/ShadedNettyChannelFactory.java b/core/src/main/java/tech/ydb/core/impl/pool/ShadedNettyChannelFactory.java index d4c050f86..b05d8c47c 100644 --- a/core/src/main/java/tech/ydb/core/impl/pool/ShadedNettyChannelFactory.java +++ b/core/src/main/java/tech/ydb/core/impl/pool/ShadedNettyChannelFactory.java @@ -10,7 +10,6 @@ import io.grpc.ClientInterceptor; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; -import io.grpc.Metadata; import io.grpc.internal.DnsNameResolverProvider; import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; import io.grpc.netty.shaded.io.grpc.netty.NegotiationType; @@ -19,7 +18,6 @@ import io.grpc.netty.shaded.io.netty.channel.ChannelOption; import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext; import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder; -import io.grpc.stub.MetadataUtils; import tech.ydb.core.grpc.GrpcTransportBuilder; import tech.ydb.core.grpc.YdbHeaders; @@ -33,8 +31,8 @@ public class ShadedNettyChannelFactory implements ManagedChannelFactory { static final int INBOUND_MESSAGE_SIZE = 64 << 20; // 64 MiB static final String DEFAULT_BALANCER_POLICY = "round_robin"; - private final String database; - private final String version; + private final ClientInterceptor metadata; + private final boolean useTLS; private final byte[] cert; private final boolean retryEnabled; @@ -43,9 +41,8 @@ public class ShadedNettyChannelFactory implements ManagedChannelFactory { private final Long grpcKeepAliveTimeMillis; private final List>> initializers; - public ShadedNettyChannelFactory(GrpcTransportBuilder builder) { - this.database = builder.getDatabase(); - this.version = builder.getVersionString(); + private ShadedNettyChannelFactory(GrpcTransportBuilder builder) { + this.metadata = YdbHeaders.createMetadataInterceptor(builder); this.useTLS = builder.getUseTls(); this.cert = builder.getCert(); this.retryEnabled = builder.isEnableRetry(); @@ -81,7 +78,7 @@ public ManagedChannel newManagedChannel(String host, int port, String sslHostOve .maxInboundMessageSize(INBOUND_MESSAGE_SIZE) .withOption(ChannelOption.ALLOCATOR, ByteBufAllocator.DEFAULT) .withOption(ChannelOption.TCP_NODELAY, true) - .intercept(metadataInterceptor()); + .intercept(metadata); if (!useDefaultGrpcResolver) { // force usage of dns resolver and round_robin balancer @@ -114,13 +111,6 @@ protected void configure(NettyChannelBuilder channelBuilder) { } - private ClientInterceptor metadataInterceptor() { - Metadata extraHeaders = new Metadata(); - extraHeaders.put(YdbHeaders.DATABASE, database); - extraHeaders.put(YdbHeaders.BUILD_INFO, version); - return MetadataUtils.newAttachHeadersInterceptor(extraHeaders); - } - private SslContext createSslContext() { try { SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient(); diff --git a/core/src/test/java/tech/ydb/core/impl/pool/DefaultChannelFactoryTest.java b/core/src/test/java/tech/ydb/core/impl/pool/DefaultChannelFactoryTest.java index 1f882a9c6..cad1ac1ef 100644 --- a/core/src/test/java/tech/ydb/core/impl/pool/DefaultChannelFactoryTest.java +++ b/core/src/test/java/tech/ydb/core/impl/pool/DefaultChannelFactoryTest.java @@ -12,15 +12,18 @@ import io.grpc.ClientInterceptor; import io.grpc.ForwardingChannelBuilder2; import io.grpc.ManagedChannel; +import io.grpc.Metadata; import io.grpc.netty.shaded.io.grpc.netty.NegotiationType; import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder; import io.grpc.netty.shaded.io.netty.buffer.ByteBufAllocator; import io.grpc.netty.shaded.io.netty.channel.ChannelOption; import io.grpc.netty.shaded.io.netty.handler.ssl.util.SelfSignedCertificate; +import io.grpc.stub.MetadataUtils; import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatchers; import org.mockito.MockedStatic; import org.mockito.Mockito; @@ -28,6 +31,8 @@ import tech.ydb.core.grpc.GrpcTransport; import tech.ydb.core.grpc.GrpcTransportBuilder; +import tech.ydb.core.grpc.YdbHeaders; +import tech.ydb.core.utils.Version; /** * @@ -41,8 +46,11 @@ public class DefaultChannelFactoryTest { private AutoCloseable mocks; private MockedStatic channelStaticMock; + private MockedStatic metadataUtilsStaticMock; private final NettyChannelBuilder channelBuilderMock = Mockito.mock(NettyChannelBuilder.class); private final ManagedChannel channelMock = Mockito.mock(ManagedChannel.class); + private final ArgumentCaptor metadataCapture = ArgumentCaptor.forClass(Metadata.class); + private final ClientInterceptor clientInterceptor = Mockito.mock(ClientInterceptor.class); @Before @SuppressWarnings("deprecation") @@ -50,6 +58,9 @@ public void setUp() { mocks = MockitoAnnotations.openMocks(this); channelStaticMock = Mockito.mockStatic(NettyChannelBuilder.class); channelStaticMock.when(FOR_ADDRESS).thenReturn(channelBuilderMock); + metadataUtilsStaticMock = Mockito.mockStatic(MetadataUtils.class); + metadataUtilsStaticMock.when(() -> MetadataUtils.newAttachHeadersInterceptor(metadataCapture.capture())) + .thenReturn(clientInterceptor); Mockito.when(channelBuilderMock.negotiationType(ArgumentMatchers.any())) .thenReturn(channelBuilderMock); @@ -70,6 +81,7 @@ public void setUp() { @After public void tearDown() throws Exception { channelStaticMock.close(); + metadataUtilsStaticMock.close(); mocks.close(); } @@ -94,6 +106,13 @@ public void defaultParams() { .withOption(ChannelOption.ALLOCATOR, ByteBufAllocator.DEFAULT); Mockito.verify(channelBuilderMock, Mockito.times(0)).enableRetry(); Mockito.verify(channelBuilderMock, Mockito.times(1)).disableRetry(); + + + Metadata metadata = metadataCapture.getValue(); + Assert.assertEquals("/Root", metadata.get(YdbHeaders.DATABASE)); + Assert.assertEquals("ydb-java-sdk/" + Version.getVersion().get(), metadata.get(YdbHeaders.BUILD_INFO)); + Assert.assertNull(metadata.get(YdbHeaders.APPLICATION_NAME)); + Assert.assertNull(metadata.get(YdbHeaders.CLIENT_PROCESS_ID)); } @Test @@ -123,6 +142,23 @@ public void defaultSslFactory() { Mockito.verify(channelBuilderMock, Mockito.times(0)).disableRetry(); } + @Test + public void customHeadersTest() { + GrpcTransportBuilder builder = GrpcTransport.forHost(MOCKED_HOST, MOCKED_PORT, "/Root") + .withApplicationName("test-application") + .withClientProcessId("client-hostname"); + ManagedChannelFactory factory = ChannelFactoryLoader.load().buildFactory(builder); + + Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT, null)); + channelStaticMock.verify(FOR_ADDRESS, Mockito.times(1)); + + Metadata metadata = metadataCapture.getValue(); + Assert.assertEquals("/Root", metadata.get(YdbHeaders.DATABASE)); + Assert.assertEquals("ydb-java-sdk/" + Version.getVersion().get(), metadata.get(YdbHeaders.BUILD_INFO)); + Assert.assertEquals("test-application", metadata.get(YdbHeaders.APPLICATION_NAME)); + Assert.assertEquals("client-hostname", metadata.get(YdbHeaders.CLIENT_PROCESS_ID)); + } + @Test public void customChannelInitializer() { GrpcTransportBuilder builder = GrpcTransport.forHost(MOCKED_HOST, MOCKED_PORT, "/Root")