Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions core/src/main/java/tech/ydb/core/grpc/GrpcTransportBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ public enum InitMode {

private byte[] cert = null;
private boolean useTLS = false;
private String applicationName = null;
private String clientProcessId = null;
private ManagedChannelFactory.Builder channelFactoryBuilder = null;
private final List<Consumer<? super ManagedChannelBuilder<?>>> channelInitializers = new ArrayList<>();
private Supplier<ScheduledExecutorService> schedulerFactory = YdbSchedulerFactory::createScheduler;
Expand Down Expand Up @@ -127,6 +129,14 @@ public String getVersionString() {
.orElse(Version.UNKNOWN_VERSION);
}

public String getApplicationName() {
return applicationName;
}

public String getClientProcessId() {
return clientProcessId;
}

public Supplier<ScheduledExecutorService> getSchedulerFactory() {
return schedulerFactory;
}
Expand Down Expand Up @@ -255,6 +265,16 @@ public GrpcTransportBuilder withSecureConnection() {
return this;
}

public GrpcTransportBuilder withApplicationName(String applicationName) {
this.applicationName = applicationName;
return this;
}

public GrpcTransportBuilder withClientProcessId(String clientProcessId) {
this.clientProcessId = clientProcessId;
return this;
}

public GrpcTransportBuilder withBalancingSettings(BalancingSettings balancingSettings) {
this.balancingSettings = balancingSettings;
return this;
Expand Down
23 changes: 23 additions & 0 deletions core/src/main/java/tech/ydb/core/grpc/YdbHeaders.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package tech.ydb.core.grpc;

import io.grpc.ClientInterceptor;
import io.grpc.Metadata;
import io.grpc.stub.MetadataUtils;


/**
Expand All @@ -22,5 +24,26 @@ public class YdbHeaders {
public static final Metadata.Key<String> YDB_SERVER_HINTS =
Metadata.Key.of("x-ydb-server-hints", Metadata.ASCII_STRING_MARSHALLER);

public static final Metadata.Key<String> APPLICATION_NAME =
Metadata.Key.of("x-ydb-application-name", Metadata.ASCII_STRING_MARSHALLER);

public static final Metadata.Key<String> CLIENT_PROCESS_ID =
Metadata.Key.of("x-ydb-client-pid", Metadata.ASCII_STRING_MARSHALLER);

private YdbHeaders() { }

public static ClientInterceptor createMetadataInterceptor(GrpcTransportBuilder builder) {
Metadata extraHeaders = new Metadata();
extraHeaders.put(YdbHeaders.DATABASE, builder.getDatabase());
extraHeaders.put(YdbHeaders.BUILD_INFO, builder.getVersionString());
String appName = builder.getApplicationName();
if (appName != null) {
extraHeaders.put(YdbHeaders.APPLICATION_NAME, appName);
}
String clientPid = builder.getClientProcessId();
if (clientPid != null) {
extraHeaders.put(YdbHeaders.CLIENT_PROCESS_ID, clientPid);
}
return MetadataUtils.newAttachHeadersInterceptor(extraHeaders);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@
import io.grpc.ClientInterceptor;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Metadata;
import io.grpc.internal.DnsNameResolverProvider;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.NegotiationType;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.MetadataUtils;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.ChannelOption;
import io.netty.handler.ssl.SslContext;
Expand All @@ -33,8 +31,8 @@ public class NettyChannelFactory implements ManagedChannelFactory {
static final int INBOUND_MESSAGE_SIZE = 64 << 20; // 64 MiB
static final String DEFAULT_BALANCER_POLICY = "round_robin";

private final String database;
private final String version;
private final ClientInterceptor metadata;

private final boolean useTLS;
private final byte[] cert;
private final boolean retryEnabled;
Expand All @@ -44,8 +42,7 @@ public class NettyChannelFactory implements ManagedChannelFactory {
private final List<Consumer<? super ManagedChannelBuilder<?>>> initializers;

private NettyChannelFactory(GrpcTransportBuilder builder) {
this.database = builder.getDatabase();
this.version = builder.getVersionString();
this.metadata = YdbHeaders.createMetadataInterceptor(builder);
this.useTLS = builder.getUseTls();
this.cert = builder.getCert();
this.retryEnabled = builder.isEnableRetry();
Expand Down Expand Up @@ -81,7 +78,7 @@ public ManagedChannel newManagedChannel(String host, int port, String sslHostOve
.maxInboundMessageSize(INBOUND_MESSAGE_SIZE)
.withOption(ChannelOption.ALLOCATOR, ByteBufAllocator.DEFAULT)
.withOption(ChannelOption.TCP_NODELAY, true)
.intercept(metadataInterceptor());
.intercept(metadata);

if (!useDefaultGrpcResolver) {
// force usage of dns resolver and round_robin balancer
Expand Down Expand Up @@ -114,13 +111,6 @@ protected void configure(NettyChannelBuilder channelBuilder) {

}

private ClientInterceptor metadataInterceptor() {
Metadata extraHeaders = new Metadata();
extraHeaders.put(YdbHeaders.DATABASE, database);
extraHeaders.put(YdbHeaders.BUILD_INFO, version);
return MetadataUtils.newAttachHeadersInterceptor(extraHeaders);
}

private SslContext createSslContext() {
try {
SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import io.grpc.ClientInterceptor;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Metadata;
import io.grpc.internal.DnsNameResolverProvider;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.NegotiationType;
Expand All @@ -19,7 +18,6 @@
import io.grpc.netty.shaded.io.netty.channel.ChannelOption;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder;
import io.grpc.stub.MetadataUtils;

import tech.ydb.core.grpc.GrpcTransportBuilder;
import tech.ydb.core.grpc.YdbHeaders;
Expand All @@ -33,8 +31,8 @@ public class ShadedNettyChannelFactory implements ManagedChannelFactory {
static final int INBOUND_MESSAGE_SIZE = 64 << 20; // 64 MiB
static final String DEFAULT_BALANCER_POLICY = "round_robin";

private final String database;
private final String version;
private final ClientInterceptor metadata;

private final boolean useTLS;
private final byte[] cert;
private final boolean retryEnabled;
Expand All @@ -43,9 +41,8 @@ public class ShadedNettyChannelFactory implements ManagedChannelFactory {
private final Long grpcKeepAliveTimeMillis;
private final List<Consumer<? super ManagedChannelBuilder<?>>> initializers;

public ShadedNettyChannelFactory(GrpcTransportBuilder builder) {
this.database = builder.getDatabase();
this.version = builder.getVersionString();
private ShadedNettyChannelFactory(GrpcTransportBuilder builder) {
this.metadata = YdbHeaders.createMetadataInterceptor(builder);
this.useTLS = builder.getUseTls();
this.cert = builder.getCert();
this.retryEnabled = builder.isEnableRetry();
Expand Down Expand Up @@ -81,7 +78,7 @@ public ManagedChannel newManagedChannel(String host, int port, String sslHostOve
.maxInboundMessageSize(INBOUND_MESSAGE_SIZE)
.withOption(ChannelOption.ALLOCATOR, ByteBufAllocator.DEFAULT)
.withOption(ChannelOption.TCP_NODELAY, true)
.intercept(metadataInterceptor());
.intercept(metadata);

if (!useDefaultGrpcResolver) {
// force usage of dns resolver and round_robin balancer
Expand Down Expand Up @@ -114,13 +111,6 @@ protected void configure(NettyChannelBuilder channelBuilder) {

}

private ClientInterceptor metadataInterceptor() {
Metadata extraHeaders = new Metadata();
extraHeaders.put(YdbHeaders.DATABASE, database);
extraHeaders.put(YdbHeaders.BUILD_INFO, version);
return MetadataUtils.newAttachHeadersInterceptor(extraHeaders);
}

private SslContext createSslContext() {
try {
SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,27 @@
import io.grpc.ClientInterceptor;
import io.grpc.ForwardingChannelBuilder2;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.netty.shaded.io.grpc.netty.NegotiationType;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.shaded.io.netty.buffer.ByteBufAllocator;
import io.grpc.netty.shaded.io.netty.channel.ChannelOption;
import io.grpc.netty.shaded.io.netty.handler.ssl.util.SelfSignedCertificate;
import io.grpc.stub.MetadataUtils;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;

import tech.ydb.core.grpc.GrpcTransport;
import tech.ydb.core.grpc.GrpcTransportBuilder;
import tech.ydb.core.grpc.YdbHeaders;
import tech.ydb.core.utils.Version;

/**
*
Expand All @@ -41,15 +46,21 @@ public class DefaultChannelFactoryTest {

private AutoCloseable mocks;
private MockedStatic<NettyChannelBuilder> channelStaticMock;
private MockedStatic<MetadataUtils> metadataUtilsStaticMock;
private final NettyChannelBuilder channelBuilderMock = Mockito.mock(NettyChannelBuilder.class);
private final ManagedChannel channelMock = Mockito.mock(ManagedChannel.class);
private final ArgumentCaptor<Metadata> metadataCapture = ArgumentCaptor.forClass(Metadata.class);
private final ClientInterceptor clientInterceptor = Mockito.mock(ClientInterceptor.class);

@Before
@SuppressWarnings("deprecation")
public void setUp() {
mocks = MockitoAnnotations.openMocks(this);
channelStaticMock = Mockito.mockStatic(NettyChannelBuilder.class);
channelStaticMock.when(FOR_ADDRESS).thenReturn(channelBuilderMock);
metadataUtilsStaticMock = Mockito.mockStatic(MetadataUtils.class);
metadataUtilsStaticMock.when(() -> MetadataUtils.newAttachHeadersInterceptor(metadataCapture.capture()))
.thenReturn(clientInterceptor);

Mockito.when(channelBuilderMock.negotiationType(ArgumentMatchers.any()))
.thenReturn(channelBuilderMock);
Expand All @@ -70,6 +81,7 @@ public void setUp() {
@After
public void tearDown() throws Exception {
channelStaticMock.close();
metadataUtilsStaticMock.close();
mocks.close();
}

Expand All @@ -94,6 +106,13 @@ public void defaultParams() {
.withOption(ChannelOption.ALLOCATOR, ByteBufAllocator.DEFAULT);
Mockito.verify(channelBuilderMock, Mockito.times(0)).enableRetry();
Mockito.verify(channelBuilderMock, Mockito.times(1)).disableRetry();


Metadata metadata = metadataCapture.getValue();
Assert.assertEquals("/Root", metadata.get(YdbHeaders.DATABASE));
Assert.assertEquals("ydb-java-sdk/" + Version.getVersion().get(), metadata.get(YdbHeaders.BUILD_INFO));
Assert.assertNull(metadata.get(YdbHeaders.APPLICATION_NAME));
Assert.assertNull(metadata.get(YdbHeaders.CLIENT_PROCESS_ID));
}

@Test
Expand Down Expand Up @@ -123,6 +142,23 @@ public void defaultSslFactory() {
Mockito.verify(channelBuilderMock, Mockito.times(0)).disableRetry();
}

@Test
public void customHeadersTest() {
GrpcTransportBuilder builder = GrpcTransport.forHost(MOCKED_HOST, MOCKED_PORT, "/Root")
.withApplicationName("test-application")
.withClientProcessId("client-hostname");
ManagedChannelFactory factory = ChannelFactoryLoader.load().buildFactory(builder);

Assert.assertSame(channelMock, factory.newManagedChannel(MOCKED_HOST, MOCKED_PORT, null));
channelStaticMock.verify(FOR_ADDRESS, Mockito.times(1));

Metadata metadata = metadataCapture.getValue();
Assert.assertEquals("/Root", metadata.get(YdbHeaders.DATABASE));
Assert.assertEquals("ydb-java-sdk/" + Version.getVersion().get(), metadata.get(YdbHeaders.BUILD_INFO));
Assert.assertEquals("test-application", metadata.get(YdbHeaders.APPLICATION_NAME));
Assert.assertEquals("client-hostname", metadata.get(YdbHeaders.CLIENT_PROCESS_ID));
}

@Test
public void customChannelInitializer() {
GrpcTransportBuilder builder = GrpcTransport.forHost(MOCKED_HOST, MOCKED_PORT, "/Root")
Expand Down