diff --git a/api/api.go b/api/api.go index 22cc096c..b457f047 100644 --- a/api/api.go +++ b/api/api.go @@ -38,8 +38,8 @@ type API struct { Options *Options Config *config.Config PluginRegistry *plugin.Registry - Pools map[string]*pool.Pool - Proxies map[string]*network.Proxy + Pools map[string]map[string]*pool.Pool + Proxies map[string]map[string]*network.Proxy Servers map[string]*network.Server } @@ -205,12 +205,17 @@ func (a *API) GetPools(context.Context, *emptypb.Empty) (*structpb.Struct, error _, span := otel.Tracer(config.TracerName).Start(a.ctx, "Get Pools") defer span.End() - pools := make(map[string]interface{}) - for name, p := range a.Pools { - pools[name] = map[string]interface{}{ - "cap": p.Cap(), - "size": p.Size(), + pools := make(map[string]any) + + for configGroupName, configGroupPools := range a.Pools { + groupPools := make(map[string]any) + for name, p := range configGroupPools { + groupPools[name] = map[string]any{ + "cap": p.Cap(), + "size": p.Size(), + } } + pools[configGroupName] = groupPools } poolsConfig, err := structpb.NewStruct(pools) @@ -231,23 +236,31 @@ func (a *API) GetProxies(context.Context, *emptypb.Empty) (*structpb.Struct, err _, span := otel.Tracer(config.TracerName).Start(a.ctx, "Get Proxies") defer span.End() - proxies := make(map[string]interface{}) - for name, proxy := range a.Proxies { - available := make([]interface{}, 0) - for _, c := range proxy.AvailableConnectionsString() { - available = append(available, c) - } + // Create a new map to hold the flattened proxies data + proxies := make(map[string]any) - busy := make([]interface{}, 0) - for _, conn := range proxy.BusyConnectionsString() { - busy = append(busy, conn) - } + for configGroupName, configGroupProxies := range a.Proxies { + // Create a map for each configuration group + groupProxies := make(map[string]any) + for name, proxy := range configGroupProxies { + available := make([]any, 0) + for _, c := range proxy.AvailableConnectionsString() { + available = append(available, c) + } - proxies[name] = map[string]interface{}{ - "available": available, - "busy": busy, - "total": len(available) + len(busy), + busy := make([]any, 0) + for _, conn := range proxy.BusyConnectionsString() { + busy = append(busy, conn) + } + + groupProxies[name] = map[string]any{ + "available": available, + "busy": busy, + "total": len(available) + len(busy), + } } + + proxies[configGroupName] = groupProxies } proxiesConfig, err := structpb.NewStruct(proxies) @@ -268,13 +281,14 @@ func (a *API) GetServers(context.Context, *emptypb.Empty) (*structpb.Struct, err _, span := otel.Tracer(config.TracerName).Start(a.ctx, "Get Servers") defer span.End() - servers := make(map[string]interface{}) + servers := make(map[string]any) for name, server := range a.Servers { - servers[name] = map[string]interface{}{ + servers[name] = map[string]any{ "network": server.Network, "address": server.Address, "status": uint(server.Status), "tickInterval": server.TickInterval.Nanoseconds(), + "loadBalancer": map[string]any{"strategy": server.LoadbalancerStrategyName}, } } diff --git a/api/api_helpers_test.go b/api/api_helpers_test.go index 9e383333..7e5bba9b 100644 --- a/api/api_helpers_test.go +++ b/api/api_helpers_test.go @@ -49,7 +49,7 @@ func getAPIConfig() *API { context.Background(), network.Server{ Logger: logger, - Proxy: defaultProxy, + Proxies: []network.IProxy{defaultProxy}, PluginRegistry: pluginReg, PluginTimeout: config.DefaultPluginTimeout, Network: "tcp", @@ -73,11 +73,11 @@ func getAPIConfig() *API { }, ), PluginRegistry: 
pluginReg, - Pools: map[string]*pool.Pool{ - config.Default: defaultPool, + Pools: map[string]map[string]*pool.Pool{ + config.Default: {config.DefaultConfigurationBlock: defaultPool}, }, - Proxies: map[string]*network.Proxy{ - config.Default: defaultProxy, + Proxies: map[string]map[string]*network.Proxy{ + config.Default: {config.DefaultConfigurationBlock: defaultProxy}, }, Servers: servers, } diff --git a/api/api_test.go b/api/api_test.go index 692a8d36..205d4975 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -210,8 +210,8 @@ func TestGetPluginsWithEmptyPluginRegistry(t *testing.T) { func TestPools(t *testing.T) { api := API{ - Pools: map[string]*pool.Pool{ - config.Default: pool.NewPool(context.TODO(), config.EmptyPoolCapacity), + Pools: map[string]map[string]*pool.Pool{ + config.Default: {config.DefaultConfigurationBlock: pool.NewPool(context.TODO(), config.EmptyPoolCapacity)}, }, ctx: context.Background(), } @@ -219,12 +219,17 @@ func TestPools(t *testing.T) { require.NoError(t, err) assert.NotEmpty(t, pools) assert.NotEmpty(t, pools.AsMap()) - assert.Equal(t, pools.AsMap()[config.Default], map[string]interface{}{"cap": 0.0, "size": 0.0}) + + assert.Equal(t, + map[string]any{ + config.DefaultConfigurationBlock: map[string]any{"cap": 0.0, "size": 0.0}, + }, + pools.AsMap()[config.Default]) } func TestPoolsWithEmptyPools(t *testing.T) { api := API{ - Pools: map[string]*pool.Pool{}, + Pools: map[string]map[string]*pool.Pool{}, ctx: context.Background(), } pools, err := api.GetPools(context.Background(), &emptypb.Empty{}) @@ -258,8 +263,8 @@ func TestGetProxies(t *testing.T) { ) api := API{ - Proxies: map[string]*network.Proxy{ - config.Default: proxy, + Proxies: map[string]map[string]*network.Proxy{ + config.Default: {config.DefaultConfigurationBlock: proxy}, }, ctx: context.Background(), } @@ -268,10 +273,14 @@ func TestGetProxies(t *testing.T) { assert.NotEmpty(t, proxies) assert.NotEmpty(t, proxies.AsMap()) - if defaultProxy, ok := proxies.AsMap()[config.Default].(map[string]interface{}); ok { - assert.Equal(t, 1.0, defaultProxy["total"]) - assert.NotEmpty(t, defaultProxy["available"]) - assert.Empty(t, defaultProxy["busy"]) + if defaultProxies, ok := proxies.AsMap()[config.Default].(map[string]any); ok { + if defaultProxy, ok := defaultProxies[config.DefaultConfigurationBlock].(map[string]any); ok { + assert.Equal(t, 1.0, defaultProxy["total"]) + assert.NotEmpty(t, defaultProxy["available"]) + assert.Empty(t, defaultProxy["busy"]) + } else { + t.Errorf("proxies.default.%s is not found or not a map", config.DefaultConfigurationBlock) + } } else { t.Errorf("proxies.default is not found or not a map") } @@ -333,20 +342,21 @@ func TestGetServers(t *testing.T) { Options: network.Option{ EnableTicker: false, }, - Proxy: proxy, - Logger: zerolog.Logger{}, - PluginRegistry: pluginRegistry, - PluginTimeout: config.DefaultPluginTimeout, - HandshakeTimeout: config.DefaultHandshakeTimeout, + Proxies: []network.IProxy{proxy}, + Logger: zerolog.Logger{}, + PluginRegistry: pluginRegistry, + PluginTimeout: config.DefaultPluginTimeout, + HandshakeTimeout: config.DefaultHandshakeTimeout, + LoadbalancerStrategyName: config.DefaultLoadBalancerStrategy, }, ) api := API{ - Pools: map[string]*pool.Pool{ - config.Default: newPool, + Pools: map[string]map[string]*pool.Pool{ + config.Default: {config.DefaultConfigurationBlock: newPool}, }, - Proxies: map[string]*network.Proxy{ - config.Default: proxy, + Proxies: map[string]map[string]*network.Proxy{ + config.Default: {config.DefaultConfigurationBlock: 
proxy}, }, Servers: map[string]*network.Server{ config.Default: server, @@ -361,12 +371,16 @@ func TestGetServers(t *testing.T) { if defaultServer, ok := servers.AsMap()[config.Default].(map[string]interface{}); ok { assert.Equal(t, config.DefaultNetwork, defaultServer["network"]) assert.Equal(t, config.DefaultAddress, "localhost:5432") - status, ok := defaultServer["status"].(float64) - assert.True(t, ok) - assert.Equal(t, config.Stopped, config.Status(status)) - tickInterval, ok := defaultServer["tickInterval"].(float64) - assert.True(t, ok) - assert.Equal(t, config.DefaultTickInterval.Nanoseconds(), int64(tickInterval)) + statusFloat, isStatusFloat := defaultServer["status"].(float64) + assert.True(t, isStatusFloat, "status should be of type float64") + status := config.Status(statusFloat) + assert.Equal(t, config.Stopped, status) + tickIntervalFloat, isTickIntervalFloat := defaultServer["tickInterval"].(float64) + assert.True(t, isTickIntervalFloat, "tickInterval should be of type float64") + assert.Equal(t, config.DefaultTickInterval.Nanoseconds(), int64(tickIntervalFloat)) + loadBalancerMap, isLoadBalancerMap := defaultServer["loadBalancer"].(map[string]interface{}) + assert.True(t, isLoadBalancerMap, "loadBalancer should be a map") + assert.Equal(t, config.DefaultLoadBalancerStrategy, loadBalancerMap["strategy"]) } else { t.Errorf("servers.default is not found or not a map") } diff --git a/api/healthcheck_test.go b/api/healthcheck_test.go index 2aa2a5ce..0efd03f1 100644 --- a/api/healthcheck_test.go +++ b/api/healthcheck_test.go @@ -69,7 +69,7 @@ func Test_Healthchecker(t *testing.T) { Options: network.Option{ EnableTicker: false, }, - Proxy: proxy, + Proxies: []network.IProxy{proxy}, Logger: zerolog.Logger{}, PluginRegistry: pluginRegistry, PluginTimeout: config.DefaultPluginTimeout, diff --git a/cmd/run.go b/cmd/run.go index 089ba50d..0e0d70c5 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -76,9 +76,9 @@ var ( UsageReportURL = "localhost:59091" loggers = make(map[string]zerolog.Logger) - pools = make(map[string]*pool.Pool) - clients = make(map[string]*config.Client) - proxies = make(map[string]*network.Proxy) + pools = make(map[string]map[string]*pool.Pool) + clients = make(map[string]map[string]*config.Client) + proxies = make(map[string]map[string]*network.Proxy) servers = make(map[string]*network.Server) healthCheckScheduler = gocron.NewScheduler(time.UTC) @@ -622,199 +622,209 @@ var runCmd = &cobra.Command{ _, span = otel.Tracer(config.TracerName).Start(runCtx, "Create pools and clients") // Create and initialize pools of connections. - for name, cfg := range conf.Global.Pools { - logger := loggers[name] - // Check if the pool size is greater than zero. - currentPoolSize := config.If( - cfg.Size > 0, - // Check if the pool size is greater than the minimum pool size. - config.If( - cfg.Size > config.MinimumPoolSize, - cfg.Size, - config.MinimumPoolSize, - ), - config.DefaultPoolSize, - ) - pools[name] = pool.NewPool(runCtx, currentPoolSize) + for configGroupName, configGroup := range conf.Global.Pools { + for configBlockName, cfg := range configGroup { + logger := loggers[configGroupName] + // Check if the pool size is greater than zero. + currentPoolSize := config.If( + cfg.Size > 0, + // Check if the pool size is greater than the minimum pool size. 
+ config.If( + cfg.Size > config.MinimumPoolSize, + cfg.Size, + config.MinimumPoolSize, + ), + config.DefaultPoolSize, + ) - span.AddEvent("Create pool", trace.WithAttributes( - attribute.String("name", name), - attribute.Int("size", currentPoolSize), - )) + if _, ok := pools[configGroupName]; !ok { + pools[configGroupName] = make(map[string]*pool.Pool) + } + pools[configGroupName][configBlockName] = pool.NewPool(runCtx, currentPoolSize) - // Get client config from the config file. - if clientConfig, ok := conf.Global.Clients[name]; !ok { - // This ensures that the default client config is used if the pool name is not - // found in the clients section. - clients[name] = conf.Global.Clients[config.Default] - } else { - // Merge the default client config with the one from the pool. - clients[name] = clientConfig - } + span.AddEvent("Create pool", trace.WithAttributes( + attribute.String("name", configBlockName), + attribute.Int("size", currentPoolSize), + )) - // Fill the missing and zero values with the default ones. - clients[name].TCPKeepAlivePeriod = config.If( - clients[name].TCPKeepAlivePeriod > 0, - clients[name].TCPKeepAlivePeriod, - config.DefaultTCPKeepAlivePeriod, - ) - clients[name].ReceiveDeadline = config.If( - clients[name].ReceiveDeadline > 0, - clients[name].ReceiveDeadline, - config.DefaultReceiveDeadline, - ) - clients[name].ReceiveTimeout = config.If( - clients[name].ReceiveTimeout > 0, - clients[name].ReceiveTimeout, - config.DefaultReceiveTimeout, - ) - clients[name].SendDeadline = config.If( - clients[name].SendDeadline > 0, - clients[name].SendDeadline, - config.DefaultSendDeadline, - ) - clients[name].ReceiveChunkSize = config.If( - clients[name].ReceiveChunkSize > 0, - clients[name].ReceiveChunkSize, - config.DefaultChunkSize, - ) - clients[name].DialTimeout = config.If( - clients[name].DialTimeout > 0, - clients[name].DialTimeout, - config.DefaultDialTimeout, - ) + if _, ok := clients[configGroupName]; !ok { + clients[configGroupName] = make(map[string]*config.Client) + } - // Add clients to the pool. - for range currentPoolSize { - clientConfig := clients[name] - client := network.NewClient( - runCtx, clientConfig, logger, - network.NewRetry( - network.Retry{ - Retries: clientConfig.Retries, - Backoff: config.If( - clientConfig.Backoff > 0, - clientConfig.Backoff, - config.DefaultBackoff, - ), - BackoffMultiplier: clientConfig.BackoffMultiplier, - DisableBackoffCaps: clientConfig.DisableBackoffCaps, - Logger: loggers[name], - }, - ), + // Get client config from the config file. + if clientConfig, ok := conf.Global.Clients[configGroupName][configBlockName]; !ok { + // This ensures that the default client config is used if the pool name is not + // found in the clients section. + clients[configGroupName][configBlockName] = conf.Global.Clients[config.Default][config.DefaultConfigurationBlock] + } else { + // Merge the default client config with the one from the pool. + clients[configGroupName][configBlockName] = clientConfig + } + + // Fill the missing and zero values with the default ones. 
+ clients[configGroupName][configBlockName].TCPKeepAlivePeriod = config.If( + clients[configGroupName][configBlockName].TCPKeepAlivePeriod > 0, + clients[configGroupName][configBlockName].TCPKeepAlivePeriod, + config.DefaultTCPKeepAlivePeriod, + ) + clients[configGroupName][configBlockName].ReceiveDeadline = config.If( + clients[configGroupName][configBlockName].ReceiveDeadline > 0, + clients[configGroupName][configBlockName].ReceiveDeadline, + config.DefaultReceiveDeadline, + ) + clients[configGroupName][configBlockName].ReceiveTimeout = config.If( + clients[configGroupName][configBlockName].ReceiveTimeout > 0, + clients[configGroupName][configBlockName].ReceiveTimeout, + config.DefaultReceiveTimeout, + ) + clients[configGroupName][configBlockName].SendDeadline = config.If( + clients[configGroupName][configBlockName].SendDeadline > 0, + clients[configGroupName][configBlockName].SendDeadline, + config.DefaultSendDeadline, + ) + clients[configGroupName][configBlockName].ReceiveChunkSize = config.If( + clients[configGroupName][configBlockName].ReceiveChunkSize > 0, + clients[configGroupName][configBlockName].ReceiveChunkSize, + config.DefaultChunkSize, + ) + clients[configGroupName][configBlockName].DialTimeout = config.If( + clients[configGroupName][configBlockName].DialTimeout > 0, + clients[configGroupName][configBlockName].DialTimeout, + config.DefaultDialTimeout, ) - if client != nil { - eventOptions := trace.WithAttributes( - attribute.String("name", name), - attribute.String("network", client.Network), - attribute.String("address", client.Address), - attribute.Int("receiveChunkSize", client.ReceiveChunkSize), - attribute.String("receiveDeadline", client.ReceiveDeadline.String()), - attribute.String("receiveTimeout", client.ReceiveTimeout.String()), - attribute.String("sendDeadline", client.SendDeadline.String()), - attribute.String("dialTimeout", client.DialTimeout.String()), - attribute.Bool("tcpKeepAlive", client.TCPKeepAlive), - attribute.String("tcpKeepAlivePeriod", client.TCPKeepAlivePeriod.String()), - attribute.String("localAddress", client.LocalAddr()), - attribute.String("remoteAddress", client.RemoteAddr()), - attribute.Int("retries", clientConfig.Retries), - attribute.String("backoff", client.Retry().Backoff.String()), - attribute.Float64("backoffMultiplier", clientConfig.BackoffMultiplier), - attribute.Bool("disableBackoffCaps", clientConfig.DisableBackoffCaps), + // Add clients to the pool. 
+ for range currentPoolSize { + clientConfig := clients[configGroupName][configBlockName] + client := network.NewClient( + runCtx, clientConfig, logger, + network.NewRetry( + network.Retry{ + Retries: clientConfig.Retries, + Backoff: config.If( + clientConfig.Backoff > 0, + clientConfig.Backoff, + config.DefaultBackoff, + ), + BackoffMultiplier: clientConfig.BackoffMultiplier, + DisableBackoffCaps: clientConfig.DisableBackoffCaps, + Logger: loggers[configBlockName], + }, + ), ) - if client.ID != "" { - eventOptions = trace.WithAttributes( - attribute.String("id", client.ID), - ) - } - - span.AddEvent("Create client", eventOptions) - - pluginTimeoutCtx, cancel = context.WithTimeout( - context.Background(), conf.Plugin.Timeout) - defer cancel() - - clientCfg := map[string]interface{}{ - "id": client.ID, - "network": client.Network, - "address": client.Address, - "receiveChunkSize": client.ReceiveChunkSize, - "receiveDeadline": client.ReceiveDeadline.String(), - "receiveTimeout": client.ReceiveTimeout.String(), - "sendDeadline": client.SendDeadline.String(), - "dialTimeout": client.DialTimeout.String(), - "tcpKeepAlive": client.TCPKeepAlive, - "tcpKeepAlivePeriod": client.TCPKeepAlivePeriod.String(), - "localAddress": client.LocalAddr(), - "remoteAddress": client.RemoteAddr(), - "retries": clientConfig.Retries, - "backoff": client.Retry().Backoff.String(), - "backoffMultiplier": clientConfig.BackoffMultiplier, - "disableBackoffCaps": clientConfig.DisableBackoffCaps, - } - _, err := pluginRegistry.Run( - pluginTimeoutCtx, clientCfg, v1.HookName_HOOK_NAME_ON_NEW_CLIENT) - if err != nil { - logger.Error().Err(err).Msg("Failed to run OnNewClient hooks") - span.RecordError(err) - } - err = pools[name].Put(client.ID, client) - if err != nil { - logger.Error().Err(err).Msg("Failed to add client to the pool") - span.RecordError(err) + if client != nil { + eventOptions := trace.WithAttributes( + attribute.String("name", configBlockName), + attribute.String("network", client.Network), + attribute.String("address", client.Address), + attribute.Int("receiveChunkSize", client.ReceiveChunkSize), + attribute.String("receiveDeadline", client.ReceiveDeadline.String()), + attribute.String("receiveTimeout", client.ReceiveTimeout.String()), + attribute.String("sendDeadline", client.SendDeadline.String()), + attribute.String("dialTimeout", client.DialTimeout.String()), + attribute.Bool("tcpKeepAlive", client.TCPKeepAlive), + attribute.String("tcpKeepAlivePeriod", client.TCPKeepAlivePeriod.String()), + attribute.String("localAddress", client.LocalAddr()), + attribute.String("remoteAddress", client.RemoteAddr()), + attribute.Int("retries", clientConfig.Retries), + attribute.String("backoff", client.Retry().Backoff.String()), + attribute.Float64("backoffMultiplier", clientConfig.BackoffMultiplier), + attribute.Bool("disableBackoffCaps", clientConfig.DisableBackoffCaps), + ) + if client.ID != "" { + eventOptions = trace.WithAttributes( + attribute.String("id", client.ID), + ) + } + + span.AddEvent("Create client", eventOptions) + + pluginTimeoutCtx, cancel = context.WithTimeout( + context.Background(), conf.Plugin.Timeout) + defer cancel() + + clientCfg := map[string]interface{}{ + "id": client.ID, + "network": client.Network, + "address": client.Address, + "receiveChunkSize": client.ReceiveChunkSize, + "receiveDeadline": client.ReceiveDeadline.String(), + "receiveTimeout": client.ReceiveTimeout.String(), + "sendDeadline": client.SendDeadline.String(), + "dialTimeout": client.DialTimeout.String(), + "tcpKeepAlive": 
client.TCPKeepAlive, + "tcpKeepAlivePeriod": client.TCPKeepAlivePeriod.String(), + "localAddress": client.LocalAddr(), + "remoteAddress": client.RemoteAddr(), + "retries": clientConfig.Retries, + "backoff": client.Retry().Backoff.String(), + "backoffMultiplier": clientConfig.BackoffMultiplier, + "disableBackoffCaps": clientConfig.DisableBackoffCaps, + } + _, err := pluginRegistry.Run( + pluginTimeoutCtx, clientCfg, v1.HookName_HOOK_NAME_ON_NEW_CLIENT) + if err != nil { + logger.Error().Err(err).Msg("Failed to run OnNewClient hooks") + span.RecordError(err) + } + + err = pools[configGroupName][configBlockName].Put(client.ID, client) + if err != nil { + logger.Error().Err(err).Msg("Failed to add client to the pool") + span.RecordError(err) + } + } else { + logger.Error().Msg("Failed to create client, please check the configuration") + go func() { + // Wait for the stop signal to exit gracefully. + // This prevents the program from waiting indefinitely + // after the StopGracefully function is called. + <-stopChan + os.Exit(gerr.FailedToCreateClient) + }() + StopGracefully( + runCtx, + nil, + metricsMerger, + metricsServer, + pluginRegistry, + logger, + servers, + stopChan, + httpServer, + grpcServer, + ) } - } else { - logger.Error().Msg("Failed to create client, please check the configuration") - go func() { - // Wait for the stop signal to exit gracefully. - // This prevents the program from waiting indefinitely - // after the StopGracefully function is called. - <-stopChan - os.Exit(gerr.FailedToCreateClient) - }() - StopGracefully( - runCtx, - nil, - metricsMerger, - metricsServer, - pluginRegistry, - logger, - servers, - stopChan, - httpServer, - grpcServer, - ) } - } - // Verify that the pool is properly populated. - logger.Info().Fields(map[string]interface{}{ - "name": name, - "count": strconv.Itoa(pools[name].Size()), - }).Msg("There are clients available in the pool") - - if pools[name].Size() != currentPoolSize { - logger.Error().Msg( - "The pool size is incorrect, either because " + - "the clients cannot connect due to no network connectivity " + - "or the server is not running. exiting...") - pluginRegistry.Shutdown() - os.Exit(gerr.FailedToInitializePool) - } + // Verify that the pool is properly populated. + logger.Info().Fields(map[string]interface{}{ + "name": configBlockName, + "count": strconv.Itoa(pools[configGroupName][configBlockName].Size()), + }).Msg("There are clients available in the pool") + + if pools[configGroupName][configBlockName].Size() != currentPoolSize { + logger.Error().Msg( + "The pool size is incorrect, either because " + + "the clients cannot connect due to no network connectivity " + + "or the server is not running. 
exiting...") + pluginRegistry.Shutdown() + os.Exit(gerr.FailedToInitializePool) + } - pluginTimeoutCtx, cancel = context.WithTimeout( - context.Background(), conf.Plugin.Timeout) - defer cancel() + pluginTimeoutCtx, cancel = context.WithTimeout( + context.Background(), conf.Plugin.Timeout) + defer cancel() - _, err = pluginRegistry.Run( - pluginTimeoutCtx, - map[string]interface{}{"name": name, "size": currentPoolSize}, - v1.HookName_HOOK_NAME_ON_NEW_POOL) - if err != nil { - logger.Error().Err(err).Msg("Failed to run OnNewPool hooks") - span.RecordError(err) + _, err = pluginRegistry.Run( + pluginTimeoutCtx, + map[string]interface{}{"name": configBlockName, "size": currentPoolSize}, + v1.HookName_HOOK_NAME_ON_NEW_POOL) + if err != nil { + logger.Error().Err(err).Msg("Failed to run OnNewPool hooks") + span.RecordError(err) + } } } @@ -822,46 +832,54 @@ var runCmd = &cobra.Command{ _, span = otel.Tracer(config.TracerName).Start(runCtx, "Create proxies") // Create and initialize prefork proxies with each pool of clients. - for name, cfg := range conf.Global.Proxies { - logger := loggers[name] - clientConfig := clients[name] - // Fill the missing and zero value with the default one. - cfg.HealthCheckPeriod = config.If( - cfg.HealthCheckPeriod > 0, - cfg.HealthCheckPeriod, - config.DefaultHealthCheckPeriod, - ) + for configGroupName, configGroup := range conf.Global.Proxies { + for configBlockName, cfg := range configGroup { + logger := loggers[configGroupName] + clientConfig := clients[configGroupName][configBlockName] + + // Fill the missing and zero value with the default one. + cfg.HealthCheckPeriod = config.If( + cfg.HealthCheckPeriod > 0, + cfg.HealthCheckPeriod, + config.DefaultHealthCheckPeriod, + ) - proxies[name] = network.NewProxy( - runCtx, - network.Proxy{ - AvailableConnections: pools[name], - PluginRegistry: pluginRegistry, - HealthCheckPeriod: cfg.HealthCheckPeriod, - ClientConfig: clientConfig, - Logger: logger, - PluginTimeout: conf.Plugin.Timeout, - }, - ) + if _, ok := proxies[configGroupName]; !ok { + proxies[configGroupName] = make(map[string]*network.Proxy) + } - span.AddEvent("Create proxy", trace.WithAttributes( - attribute.String("name", name), - attribute.String("healthCheckPeriod", cfg.HealthCheckPeriod.String()), - )) + proxies[configGroupName][configBlockName] = network.NewProxy( + runCtx, + network.Proxy{ + AvailableConnections: pools[configGroupName][configBlockName], + PluginRegistry: pluginRegistry, + HealthCheckPeriod: cfg.HealthCheckPeriod, + ClientConfig: clientConfig, + Logger: logger, + PluginTimeout: conf.Plugin.Timeout, + }, + ) - pluginTimeoutCtx, cancel = context.WithTimeout( - context.Background(), conf.Plugin.Timeout) - defer cancel() + span.AddEvent("Create proxy", trace.WithAttributes( + attribute.String("name", configBlockName), + attribute.String("healthCheckPeriod", cfg.HealthCheckPeriod.String()), + )) - if data, ok := conf.GlobalKoanf.Get("proxies").(map[string]interface{}); ok { - _, err = pluginRegistry.Run( - pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_PROXY) - if err != nil { - logger.Error().Err(err).Msg("Failed to run OnNewProxy hooks") - span.RecordError(err) + pluginTimeoutCtx, cancel = context.WithTimeout( + context.Background(), conf.Plugin.Timeout) + defer cancel() + + if data, ok := conf.GlobalKoanf.Get("proxies").(map[string]interface{}); ok { + _, err = pluginRegistry.Run( + pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_PROXY) + if err != nil { + logger.Error().Err(err).Msg("Failed to run OnNewProxy hooks") + 
span.RecordError(err) + } + } else { + logger.Error().Msg("Failed to get proxy from config") } - } else { - logger.Error().Msg("Failed to get proxy from config") + } } @@ -871,6 +889,12 @@ var runCmd = &cobra.Command{ // Create and initialize servers. for name, cfg := range conf.Global.Servers { logger := loggers[name] + + var serverProxies []network.IProxy + for _, proxy := range proxies[name] { + serverProxies = append(serverProxies, proxy) + } + servers[name] = network.NewServer( runCtx, network.Server{ @@ -885,14 +909,15 @@ var runCmd = &cobra.Command{ // Can be used to send keepalive messages to the client. EnableTicker: cfg.EnableTicker, }, - Proxy: proxies[name], - Logger: logger, - PluginRegistry: pluginRegistry, - PluginTimeout: conf.Plugin.Timeout, - EnableTLS: cfg.EnableTLS, - CertFile: cfg.CertFile, - KeyFile: cfg.KeyFile, - HandshakeTimeout: cfg.HandshakeTimeout, + Proxies: serverProxies, + Logger: logger, + PluginRegistry: pluginRegistry, + PluginTimeout: conf.Plugin.Timeout, + EnableTLS: cfg.EnableTLS, + CertFile: cfg.CertFile, + KeyFile: cfg.KeyFile, + HandshakeTimeout: cfg.HandshakeTimeout, + LoadbalancerStrategyName: cfg.LoadBalancer.Strategy, }, ) diff --git a/cmd/testdata/gatewayd.yaml b/cmd/testdata/gatewayd.yaml index 30636f12..4a9f016a 100644 --- a/cmd/testdata/gatewayd.yaml +++ b/cmd/testdata/gatewayd.yaml @@ -18,21 +18,52 @@ metrics: clients: default: - address: localhost:5432 + activeWrites: + address: localhost:5432 + network: tcp + tcpKeepAlive: False + tcpKeepAlivePeriod: 30s # duration + receiveChunkSize: 8192 + receiveDeadline: 0s # duration, 0ms/0s means no deadline + receiveTimeout: 0s # duration, 0ms/0s means no timeout + sendDeadline: 0s # duration, 0ms/0s means no deadline + dialTimeout: 60s # duration + # Retry configuration + retries: 3 # 0 means no retry and fail immediately on the first attempt + backoff: 1s # duration + backoffMultiplier: 2.0 # 0 means no backoff + disableBackoffCaps: false test: - address: localhost:5433 + write: + address: localhost:5433 + network: tcp + tcpKeepAlive: False + tcpKeepAlivePeriod: 30s # duration + receiveChunkSize: 8192 + receiveDeadline: 0s # duration, 0ms/0s means no deadline + receiveTimeout: 0s # duration, 0ms/0s means no timeout + sendDeadline: 0s # duration, 0ms/0s means no deadline + dialTimeout: 60s # duration + retries: 3 # 0 means no retry and fail immediately on the first attempt + backoff: 1s # duration + backoffMultiplier: 2.0 # 0 means no backoff + disableBackoffCaps: false pools: default: - size: 10 + activeWrites: + size: 10 test: - size: 10 + write: + size: 10 proxies: default: - healthCheckPeriod: 60s # duration + activeWrites: + healthCheckPeriod: 60s # duration test: - healthCheckPeriod: 60s # duration + write: + healthCheckPeriod: 60s # duration servers: default: diff --git a/cmd/testdata/gatewayd_tls.yaml b/cmd/testdata/gatewayd_tls.yaml index 75f584bc..f1c99f0c 100644 --- a/cmd/testdata/gatewayd_tls.yaml +++ b/cmd/testdata/gatewayd_tls.yaml @@ -12,15 +12,30 @@ metrics: clients: default: - address: localhost:5432 + activeWrites: + address: localhost:5432 + network: tcp + tcpKeepAlive: False + tcpKeepAlivePeriod: 30s + receiveChunkSize: 8192 + receiveDeadline: 0s + receiveTimeout: 0s + sendDeadline: 0s + dialTimeout: 60s + retries: 3 + backoff: 1s + backoffMultiplier: 2.0 + disableBackoffCaps: false pools: default: - size: 10 + activeWrites: + size: 10 proxies: default: - healthCheckPeriod: 60s # duration + activeWrites: + healthCheckPeriod: 60s # duration servers: default: diff --git 
a/config/config.go b/config/config.go index 0903fe58..7401cc69 100644 --- a/config/config.go +++ b/config/config.go @@ -160,14 +160,15 @@ func (c *Config) LoadDefaults(ctx context.Context) *gerr.GatewayDError { CertFile: "", KeyFile: "", HandshakeTimeout: DefaultHandshakeTimeout, + LoadBalancer: LoadBalancer{Strategy: DefaultLoadBalancerStrategy}, } c.globalDefaults = GlobalConfig{ Loggers: map[string]*Logger{Default: &defaultLogger}, Metrics: map[string]*Metrics{Default: &defaultMetric}, - Clients: map[string]*Client{Default: &defaultClient}, - Pools: map[string]*Pool{Default: &defaultPool}, - Proxies: map[string]*Proxy{Default: &defaultProxy}, + Clients: map[string]map[string]*Client{Default: {DefaultConfigurationBlock: &defaultClient}}, + Pools: map[string]map[string]*Pool{Default: {DefaultConfigurationBlock: &defaultPool}}, + Proxies: map[string]map[string]*Proxy{Default: {DefaultConfigurationBlock: &defaultProxy}}, Servers: map[string]*Server{Default: &defaultServer}, API: API{ Enabled: true, @@ -188,27 +189,59 @@ func (c *Config) LoadDefaults(ctx context.Context) *gerr.GatewayDError { } for configObject, configMap := range gconf { - if configGroup, ok := configMap.(map[string]interface{}); ok { - for configGroupKey := range configGroup { - if configGroupKey == Default { + configGroup, ok := configMap.(map[string]any) + if !ok { + err := fmt.Errorf("invalid config structure for %s", configObject) + span.RecordError(err) + span.End() + return gerr.ErrConfigParseError.Wrap(err) + } + + if configObject == "api" { + // Handle API configuration separately + // TODO: Add support for multiple API config groups. + continue + } + + for configGroupKey, configBlocksInterface := range configGroup { + if configGroupKey == Default { + continue + } + + configBlocks, ok := configBlocksInterface.(map[string]any) + if !ok { + err := fmt.Errorf("invalid config blocks structure for %s.%s", configObject, configGroupKey) + span.RecordError(err) + span.End() + return gerr.ErrConfigParseError.Wrap(err) + } + + for configBlockKey := range configBlocks { + if configBlockKey == DefaultConfigurationBlock { continue } - switch configObject { case "loggers": c.globalDefaults.Loggers[configGroupKey] = &defaultLogger case "metrics": c.globalDefaults.Metrics[configGroupKey] = &defaultMetric case "clients": - c.globalDefaults.Clients[configGroupKey] = &defaultClient + if c.globalDefaults.Clients[configGroupKey] == nil { + c.globalDefaults.Clients[configGroupKey] = make(map[string]*Client) + } + c.globalDefaults.Clients[configGroupKey][configBlockKey] = &defaultClient case "pools": - c.globalDefaults.Pools[configGroupKey] = &defaultPool + if c.globalDefaults.Pools[configGroupKey] == nil { + c.globalDefaults.Pools[configGroupKey] = make(map[string]*Pool) + } + c.globalDefaults.Pools[configGroupKey][configBlockKey] = &defaultPool case "proxies": - c.globalDefaults.Proxies[configGroupKey] = &defaultProxy + if c.globalDefaults.Proxies[configGroupKey] == nil { + c.globalDefaults.Proxies[configGroupKey] = make(map[string]*Proxy) + } + c.globalDefaults.Proxies[configGroupKey][configBlockKey] = &defaultProxy case "servers": c.globalDefaults.Servers[configGroupKey] = &defaultServer - case "api": - // TODO: Add support for multiple API config groups. 
default: err := fmt.Errorf("unknown config object: %s", configObject) span.RecordError(err) @@ -441,11 +474,18 @@ func (c *Config) ValidateGlobalConfig(ctx context.Context) *gerr.GatewayDError { seenConfigObjects = append(seenConfigObjects, "metrics") } - for configGroup := range globalConfig.Clients { - if globalConfig.Clients[configGroup] == nil { - err := fmt.Errorf("\"clients.%s\" is nil or empty", configGroup) - span.RecordError(err) - errors = append(errors, gerr.ErrValidationFailed.Wrap(err)) + clientConfigGroups := make(map[string]map[string]bool) + for configGroupName, configGroups := range globalConfig.Clients { + if _, ok := clientConfigGroups[configGroupName]; !ok { + clientConfigGroups[configGroupName] = make(map[string]bool) + } + for configGroup := range configGroups { + clientConfigGroups[configGroupName][configGroup] = true + if globalConfig.Clients[configGroupName][configGroup] == nil { + err := fmt.Errorf("\"clients.%s\" is nil or empty", configGroup) + span.RecordError(err) + errors = append(errors, gerr.ErrValidationFailed.Wrap(err)) + } } } @@ -467,7 +507,7 @@ func (c *Config) ValidateGlobalConfig(ctx context.Context) *gerr.GatewayDError { for configGroup := range globalConfig.Proxies { if globalConfig.Proxies[configGroup] == nil { - err := fmt.Errorf("\"proxies.%s\" is nil or empty", configGroup) + err := fmt.Errorf(`"proxies.%s" is nil or empty`, configGroup) span.RecordError(err) errors = append(errors, gerr.ErrValidationFailed.Wrap(err)) } @@ -489,6 +529,36 @@ func (c *Config) ValidateGlobalConfig(ctx context.Context) *gerr.GatewayDError { seenConfigObjects = append(seenConfigObjects, "servers") } + // ValidateClientsPoolsProxies checks if all configGroups in globalConfig.Pools and globalConfig.Proxies + // are referenced in globalConfig.Clients. + if len(globalConfig.Clients) != len(globalConfig.Pools) || len(globalConfig.Clients) != len(globalConfig.Proxies) { + err := goerrors.New("clients, pools, and proxies do not have the same number of objects") + span.RecordError(err) + errors = append(errors, gerr.ErrValidationFailed.Wrap(err)) + } + + // Check if all proxies are referenced in client configuration + for configGroupName, configGroups := range globalConfig.Proxies { + for configGroup := range configGroups { + if !clientConfigGroups[configGroupName][configGroup] { + err := fmt.Errorf(`"proxies.%s" not referenced in client configuration`, configGroup) + span.RecordError(err) + errors = append(errors, gerr.ErrValidationFailed.Wrap(err)) + } + } + } + + // Check if all pools are referenced in client configuration + for configGroupName, configGroups := range globalConfig.Pools { + for configGroup := range configGroups { + if !clientConfigGroups[configGroupName][configGroup] { + err := fmt.Errorf(`"pools.%s" not referenced in client configuration`, configGroup) + span.RecordError(err) + errors = append(errors, gerr.ErrValidationFailed.Wrap(err)) + } + } + } + sort.Strings(seenConfigObjects) if len(seenConfigObjects) > 0 && !reflect.DeepEqual(configObjects, seenConfigObjects) { diff --git a/config/constants.go b/config/constants.go index 4591de2c..e7a7c475 100644 --- a/config/constants.go +++ b/config/constants.go @@ -34,11 +34,12 @@ const ( const ( // Config constants. 
- Default = "default" - EnvPrefix = "GATEWAYD_" - TracerName = "gatewayd" - GlobalConfigFilename = "gatewayd.yaml" - PluginsConfigFilename = "gatewayd_plugins.yaml" + Default = "default" + DefaultConfigurationBlock = "activeWrites" + EnvPrefix = "GATEWAYD_" + TracerName = "gatewayd" + GlobalConfigFilename = "gatewayd.yaml" + PluginsConfigFilename = "gatewayd_plugins.yaml" // Logger constants. DefaultLogOutput = "console" @@ -89,10 +90,11 @@ const ( DefaultHealthCheckPeriod = 60 * time.Second // This must match PostgreSQL authentication timeout. // Server constants. - DefaultListenNetwork = "tcp" - DefaultListenAddress = "0.0.0.0:15432" - DefaultTickInterval = 5 * time.Second - DefaultHandshakeTimeout = 5 * time.Second + DefaultListenNetwork = "tcp" + DefaultListenAddress = "0.0.0.0:15432" + DefaultTickInterval = 5 * time.Second + DefaultHandshakeTimeout = 5 * time.Second + DefaultLoadBalancerStrategy = "ROUND_ROBIN" // Utility constants. DefaultSeed = 1000 @@ -124,3 +126,8 @@ const ( DefaultRedisAddress = "localhost:6379" DefaultRedisChannel = "gatewayd-actions" ) + +// Load balancing strategies. +const ( + RoundRobinStrategy = "ROUND_ROBIN" +) diff --git a/config/getters.go b/config/getters.go index 406ec208..02d1cb7f 100644 --- a/config/getters.go +++ b/config/getters.go @@ -114,9 +114,9 @@ func (gc GlobalConfig) Filter(groupName string) *GlobalConfig { } return &GlobalConfig{ Loggers: map[string]*Logger{groupName: gc.Loggers[groupName]}, - Clients: map[string]*Client{groupName: gc.Clients[groupName]}, - Pools: map[string]*Pool{groupName: gc.Pools[groupName]}, - Proxies: map[string]*Proxy{groupName: gc.Proxies[groupName]}, + Clients: map[string]map[string]*Client{groupName: gc.Clients[groupName]}, + Pools: map[string]map[string]*Pool{groupName: gc.Pools[groupName]}, + Proxies: map[string]map[string]*Proxy{groupName: gc.Proxies[groupName]}, Servers: map[string]*Server{groupName: gc.Servers[groupName]}, Metrics: map[string]*Metrics{groupName: gc.Metrics[groupName]}, API: gc.API, diff --git a/config/testdata/missing_keys.yaml b/config/testdata/missing_keys.yaml index d2311a7d..0d8eb280 100644 --- a/config/testdata/missing_keys.yaml +++ b/config/testdata/missing_keys.yaml @@ -15,25 +15,32 @@ metrics: default: enabled: True test: - enabled: True + write: + enabled: True clients: default: - address: localhost:5432 + activeWrites: + address: localhost:5432 test: - address: localhost:5433 + write: + address: localhost:5433 pools: default: - size: 10 + activeWrites: + size: 10 test: - size: 10 + write: + size: 10 proxies: default: - healthCheckPeriod: 60s # duration + activeWrites: + healthCheckPeriod: 60s # duration test: - healthCheckPeriod: 60s # duration + write: + healthCheckPeriod: 60s # duration servers: default: diff --git a/config/types.go b/config/types.go index 7a49fd63..f08e3787 100644 --- a/config/types.go +++ b/config/types.go @@ -44,19 +44,19 @@ type ActionRedisConfig struct { } type Client struct { - Network string `json:"network" jsonschema:"enum=tcp,enum=udp,enum=unix"` - Address string `json:"address"` - TCPKeepAlive bool `json:"tcpKeepAlive"` - TCPKeepAlivePeriod time.Duration `json:"tcpKeepAlivePeriod" jsonschema:"oneof_type=string;integer"` - ReceiveChunkSize int `json:"receiveChunkSize"` - ReceiveDeadline time.Duration `json:"receiveDeadline" jsonschema:"oneof_type=string;integer"` - ReceiveTimeout time.Duration `json:"receiveTimeout" jsonschema:"oneof_type=string;integer"` - SendDeadline time.Duration `json:"sendDeadline" jsonschema:"oneof_type=string;integer"` - 
DialTimeout time.Duration `json:"dialTimeout" jsonschema:"oneof_type=string;integer"` - Retries int `json:"retries"` - Backoff time.Duration `json:"backoff" jsonschema:"oneof_type=string;integer"` - BackoffMultiplier float64 `json:"backoffMultiplier"` - DisableBackoffCaps bool `json:"disableBackoffCaps"` + Network string `json:"network" jsonschema:"enum=tcp,enum=udp,enum=unix" yaml:"network"` + Address string `json:"address" yaml:"address"` + TCPKeepAlive bool `json:"tcpKeepAlive" yaml:"tcpKeepAlive"` + TCPKeepAlivePeriod time.Duration `json:"tcpKeepAlivePeriod" jsonschema:"oneof_type=string;integer" yaml:"tcpKeepAlivePeriod"` + ReceiveChunkSize int `json:"receiveChunkSize" yaml:"receiveChunkSize"` + ReceiveDeadline time.Duration `json:"receiveDeadline" jsonschema:"oneof_type=string;integer" yaml:"receiveDeadline"` + ReceiveTimeout time.Duration `json:"receiveTimeout" jsonschema:"oneof_type=string;integer" yaml:"receiveTimeout"` + SendDeadline time.Duration `json:"sendDeadline" jsonschema:"oneof_type=string;integer" yaml:"sendDeadline"` + DialTimeout time.Duration `json:"dialTimeout" jsonschema:"oneof_type=string;integer" yaml:"dialTimeout"` + Retries int `json:"retries" yaml:"retries"` + Backoff time.Duration `json:"backoff" jsonschema:"oneof_type=string;integer" yaml:"backoff"` + BackoffMultiplier float64 `json:"backoffMultiplier" yaml:"backoffMultiplier"` + DisableBackoffCaps bool `json:"disableBackoffCaps" yaml:"disableBackoffCaps"` } type Logger struct { @@ -89,11 +89,15 @@ type Metrics struct { } type Pool struct { - Size int `json:"size"` + Size int `json:"size" yaml:"size"` } type Proxy struct { - HealthCheckPeriod time.Duration `json:"healthCheckPeriod" jsonschema:"oneof_type=string;integer"` + HealthCheckPeriod time.Duration `json:"healthCheckPeriod" jsonschema:"oneof_type=string;integer" yaml:"healthCheckPeriod"` +} + +type LoadBalancer struct { + Strategy string `json:"strategy"` } type Server struct { @@ -105,6 +109,7 @@ type Server struct { CertFile string `json:"certFile"` KeyFile string `json:"keyFile"` HandshakeTimeout time.Duration `json:"handshakeTimeout" jsonschema:"oneof_type=string;integer"` + LoadBalancer LoadBalancer `json:"loadBalancer"` } type API struct { @@ -115,11 +120,11 @@ type API struct { } type GlobalConfig struct { - API API `json:"api"` - Loggers map[string]*Logger `json:"loggers"` - Clients map[string]*Client `json:"clients"` - Pools map[string]*Pool `json:"pools"` - Proxies map[string]*Proxy `json:"proxies"` - Servers map[string]*Server `json:"servers"` - Metrics map[string]*Metrics `json:"metrics"` + API API `json:"api"` + Loggers map[string]*Logger `json:"loggers"` + Clients map[string]map[string]*Client `json:"clients"` + Pools map[string]map[string]*Pool `json:"pools"` + Proxies map[string]map[string]*Proxy `json:"proxies"` + Servers map[string]*Server `json:"servers"` + Metrics map[string]*Metrics `json:"metrics"` } diff --git a/errors/errors.go b/errors/errors.go index c9868159..f4d8bc43 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -53,6 +53,8 @@ const ( ErrCodeMsgEncodeError ErrCodeConfigParseError ErrCodePublishAsyncAction + ErrCodeLoadBalancerStrategyNotFound + ErrCodeNoProxiesAvailable ) var ( @@ -194,6 +196,14 @@ var ( ErrCodePublishAsyncAction, "error publishing async action", nil, } + ErrLoadBalancerStrategyNotFound = &GatewayDError{ + ErrCodeLoadBalancerStrategyNotFound, "The specified load balancer strategy does not exist.", nil, + } + + ErrNoProxiesAvailable = &GatewayDError{ + ErrCodeNoProxiesAvailable, "No proxies available 
to select.", nil, + } + // Unwrapped errors. ErrLoggerRequired = errors.New("terminate action requires a logger parameter") ) diff --git a/gatewayd.yaml b/gatewayd.yaml index d407f1a7..47c03d5a 100644 --- a/gatewayd.yaml +++ b/gatewayd.yaml @@ -32,33 +32,58 @@ metrics: clients: default: - network: tcp - address: localhost:5432 - tcpKeepAlive: False - tcpKeepAlivePeriod: 30s # duration - receiveChunkSize: 8192 - receiveDeadline: 0s # duration, 0ms/0s means no deadline - receiveTimeout: 0s # duration, 0ms/0s means no timeout - sendDeadline: 0s # duration, 0ms/0s means no deadline - dialTimeout: 60s # duration - # Retry configuration - retries: 3 # 0 means no retry and fail immediately on the first attempt - backoff: 1s # duration - backoffMultiplier: 2.0 # 0 means no backoff - disableBackoffCaps: false + activeWrites: + network: tcp + address: localhost:5432 + tcpKeepAlive: False + tcpKeepAlivePeriod: 30s # duration + receiveChunkSize: 8192 + receiveDeadline: 0s # duration, 0ms/0s means no deadline + receiveTimeout: 0s # duration, 0ms/0s means no timeout + sendDeadline: 0s # duration, 0ms/0s means no deadline + dialTimeout: 60s # duration + # Retry configuration + retries: 3 # 0 means no retry and fail immediately on the first attempt + backoff: 1s # duration + backoffMultiplier: 2.0 # 0 means no backoff + disableBackoffCaps: false + standbyReads: + network: tcp + address: localhost:5433 + tcpKeepAlive: False + tcpKeepAlivePeriod: 30s # duration + receiveChunkSize: 8192 + receiveDeadline: 0s # duration, 0ms/0s means no deadline + receiveTimeout: 0s # duration, 0ms/0s means no timeout + sendDeadline: 0s # duration, 0ms/0s means no deadline + dialTimeout: 60s # duration + # Retry configuration + retries: 3 # 0 means no retry and fail immediately on the first attempt + backoff: 1s # duration + backoffMultiplier: 2.0 # 0 means no backoff + disableBackoffCaps: false pools: default: - size: 10 + activeWrites: + size: 10 + standbyReads: + size: 10 proxies: default: - healthCheckPeriod: 60s # duration + activeWrites: + healthCheckPeriod: 60s # duration + standbyReads: + healthCheckPeriod: 60s # duration servers: default: network: tcp address: 0.0.0.0:15432 + loadBalancer: + # Load balancer strategies can be found in config/constants.go + strategy: ROUND_ROBIN enableTicker: False tickInterval: 5s # duration enableTLS: False diff --git a/network/loadbalancer.go b/network/loadbalancer.go new file mode 100644 index 00000000..76c57d2b --- /dev/null +++ b/network/loadbalancer.go @@ -0,0 +1,19 @@ +package network + +import ( + "github.com/gatewayd-io/gatewayd/config" + gerr "github.com/gatewayd-io/gatewayd/errors" +) + +type LoadBalancerStrategy interface { + NextProxy() (IProxy, *gerr.GatewayDError) +} + +func NewLoadBalancerStrategy(server *Server) (LoadBalancerStrategy, *gerr.GatewayDError) { + switch server.LoadbalancerStrategyName { + case config.RoundRobinStrategy: + return NewRoundRobin(server), nil + default: + return nil, gerr.ErrLoadBalancerStrategyNotFound + } +} diff --git a/network/loadbalancer_test.go b/network/loadbalancer_test.go new file mode 100644 index 00000000..1d588b2a --- /dev/null +++ b/network/loadbalancer_test.go @@ -0,0 +1,44 @@ +package network + +import ( + "errors" + "testing" + + "github.com/gatewayd-io/gatewayd/config" + gerr "github.com/gatewayd-io/gatewayd/errors" +) + +// TestNewLoadBalancerStrategy tests the NewLoadBalancerStrategy function to ensure it correctly +// initializes the load balancer strategy based on the strategy name provided in the server configuration. 
+// It covers both valid and invalid strategy names. +func TestNewLoadBalancerStrategy(t *testing.T) { + serverValid := &Server{ + LoadbalancerStrategyName: config.RoundRobinStrategy, + Proxies: []IProxy{MockProxy{}}, + } + + // Test case 1: Valid strategy name + strategy, err := NewLoadBalancerStrategy(serverValid) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + _, ok := strategy.(*RoundRobin) + if !ok { + t.Errorf("Expected strategy to be of type RoundRobin") + } + + // Test case 2: InValid strategy name + serverInvalid := &Server{ + LoadbalancerStrategyName: "InvalidStrategy", + Proxies: []IProxy{MockProxy{}}, + } + + strategy, err = NewLoadBalancerStrategy(serverInvalid) + if !errors.Is(err, gerr.ErrLoadBalancerStrategyNotFound) { + t.Errorf("Expected ErrLoadBalancerStrategyNotFound, got %v", err) + } + if strategy != nil { + t.Errorf("Expected strategy to be nil for invalid strategy name") + } +} diff --git a/network/network_helpers_test.go b/network/network_helpers_test.go index 618cade2..e8e44cf5 100644 --- a/network/network_helpers_test.go +++ b/network/network_helpers_test.go @@ -5,6 +5,7 @@ import ( "strings" "testing" + gerr "github.com/gatewayd-io/gatewayd/errors" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/require" @@ -16,6 +17,11 @@ type WriteBuffer struct { msgStart int } +// MockProxy implements the IProxy interface for testing purposes. +type MockProxy struct { + name string +} + // writeStartupMsg writes a PostgreSQL startup message to the buffer. func writeStartupMsg(buf *WriteBuffer, user, database, appName string) { // Write startup message header @@ -154,3 +160,51 @@ func CollectAndComparePrometheusMetrics(t *testing.T) { require.NoError(t, testutil.GatherAndCompare(prometheus.DefaultGatherer, strings.NewReader(want), metrics...)) } + +// Connect is a mock implementation of the Connect method in the IProxy interface. +func (m MockProxy) Connect(_ *ConnWrapper) *gerr.GatewayDError { + return nil +} + +// Disconnect is a mock implementation of the Disconnect method in the IProxy interface. +func (m MockProxy) Disconnect(_ *ConnWrapper) *gerr.GatewayDError { + return nil +} + +// PassThroughToServer is a mock implementation of the PassThroughToServer method in the IProxy interface. +func (m MockProxy) PassThroughToServer(_ *ConnWrapper, _ *Stack) *gerr.GatewayDError { + return nil +} + +// PassThroughToClient is a mock implementation of the PassThroughToClient method in the IProxy interface. +func (m MockProxy) PassThroughToClient(_ *ConnWrapper, _ *Stack) *gerr.GatewayDError { + return nil +} + +// IsHealthy is a mock implementation of the IsHealthy method in the IProxy interface. +func (m MockProxy) IsHealthy(_ *Client) (*Client, *gerr.GatewayDError) { + return nil, nil +} + +// IsExhausted is a mock implementation of the IsExhausted method in the IProxy interface. +func (m MockProxy) IsExhausted() bool { + return false +} + +// Shutdown is a mock implementation of the Shutdown method in the IProxy interface. +func (m MockProxy) Shutdown() {} + +// AvailableConnectionsString is a mock implementation of the AvailableConnectionsString method in the IProxy interface. +func (m MockProxy) AvailableConnectionsString() []string { + return nil +} + +// BusyConnectionsString is a mock implementation of the BusyConnectionsString method in the IProxy interface. 
+func (m MockProxy) BusyConnectionsString() []string { + return nil +} + +// GetName returns the name of the MockProxy. +func (m MockProxy) GetName() string { + return m.name +} diff --git a/network/roundrobin.go b/network/roundrobin.go new file mode 100644 index 00000000..0057f432 --- /dev/null +++ b/network/roundrobin.go @@ -0,0 +1,26 @@ +package network + +import ( + "errors" + "sync/atomic" + + gerr "github.com/gatewayd-io/gatewayd/errors" +) + +type RoundRobin struct { + proxies []IProxy + next atomic.Uint32 +} + +func NewRoundRobin(server *Server) *RoundRobin { + return &RoundRobin{proxies: server.Proxies} +} + +func (r *RoundRobin) NextProxy() (IProxy, *gerr.GatewayDError) { + proxiesLen := uint32(len(r.proxies)) + if proxiesLen == 0 { + return nil, gerr.ErrNoProxiesAvailable.Wrap(errors.New("proxy list is empty")) + } + nextIndex := r.next.Add(1) + return r.proxies[nextIndex%proxiesLen], nil +} diff --git a/network/roundrobin_test.go b/network/roundrobin_test.go new file mode 100644 index 00000000..ec430dc9 --- /dev/null +++ b/network/roundrobin_test.go @@ -0,0 +1,117 @@ +package network + +import ( + "math" + "sync" + "testing" +) + +// TestNewRoundRobin tests the NewRoundRobin function to ensure that it correctly initializes +// the round-robin load balancer with the expected number of proxies. +func TestNewRoundRobin(t *testing.T) { + proxies := []IProxy{ + MockProxy{name: "proxy1"}, + MockProxy{name: "proxy2"}, + MockProxy{name: "proxy3"}, + } + server := &Server{Proxies: proxies} + rr := NewRoundRobin(server) + + if len(rr.proxies) != len(proxies) { + t.Errorf("expected %d proxies, got %d", len(proxies), len(rr.proxies)) + } +} + +// TestRoundRobin_NextProxy tests the NextProxy method of the round-robin load balancer to ensure +// that it returns proxies in the expected order. +func TestRoundRobin_NextProxy(t *testing.T) { + proxies := []IProxy{ + MockProxy{name: "proxy1"}, + MockProxy{name: "proxy2"}, + MockProxy{name: "proxy3"}, + } + server := &Server{Proxies: proxies} + roundRobin := NewRoundRobin(server) + + expectedOrder := []string{"proxy2", "proxy3", "proxy1", "proxy2", "proxy3"} + + for testIndex, expected := range expectedOrder { + proxy, err := roundRobin.NextProxy() + if err != nil { + t.Fatalf("test %d: unexpected error from NextProxy: %v", testIndex, err) + } + mockProxy, ok := proxy.(MockProxy) + if !ok { + t.Fatalf("test %d: expected proxy of type MockProxy, got %T", testIndex, proxy) + } + if mockProxy.GetName() != expected { + t.Errorf("test %d: expected proxy name %s, got %s", testIndex, expected, mockProxy.GetName()) + } + } +} + +// TestRoundRobin_ConcurrentAccess tests the thread safety of the NextProxy method in the round-robin load balancer +// by invoking it concurrently from multiple goroutines and ensuring that the internal state is updated correctly. 
+func TestRoundRobin_ConcurrentAccess(t *testing.T) { + proxies := []IProxy{ + MockProxy{name: "proxy1"}, + MockProxy{name: "proxy2"}, + MockProxy{name: "proxy3"}, + } + server := &Server{Proxies: proxies} + roundRobin := NewRoundRobin(server) + + var waitGroup sync.WaitGroup + numGoroutines := 100 + waitGroup.Add(numGoroutines) + + for range numGoroutines { + go func() { + defer waitGroup.Done() + _, _ = roundRobin.NextProxy() + }() + } + + waitGroup.Wait() + nextIndex := roundRobin.next.Load() + if nextIndex != uint32(numGoroutines) { + t.Errorf("expected next index to be %d, got %d", numGoroutines, nextIndex) + } +} + +// TestNextProxyOverflow verifies that the round-robin proxy selection correctly handles +// the overflow of the internal counter. It sets the counter to a value close to the maximum +// uint32 value and ensures that the proxy selection wraps around as expected when the +// counter overflows. +func TestNextProxyOverflow(t *testing.T) { + // Create a server with a few mock proxies + server := &Server{ + Proxies: []IProxy{ + &MockProxy{}, + &MockProxy{}, + &MockProxy{}, + }, + } + roundRobin := NewRoundRobin(server) + + // Set the next value to near the max uint32 value to force an overflow + roundRobin.next.Store(math.MaxUint32 - 1) + + // Call NextProxy multiple times to trigger the overflow + for range 4 { + proxy, err := roundRobin.NextProxy() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if proxy == nil { + t.Fatal("Expected a proxy, got nil") + } + } + + // After overflow, next value should wrap around + expectedNextValue := uint32(2) // (MaxUint32 - 1 + 4) % ProxiesLen = 2 + actualNextValue := roundRobin.next.Load() + if actualNextValue != expectedNextValue { + t.Fatalf("Expected next value to be %v, got %v", expectedNextValue, actualNextValue) + } +} diff --git a/network/server.go b/network/server.go index 7025f584..1d3c6b99 100644 --- a/network/server.go +++ b/network/server.go @@ -48,7 +48,7 @@ type IServer interface { } type Server struct { - Proxy IProxy + Proxies []IProxy Logger zerolog.Logger PluginRegistry *plugin.Registry ctx context.Context //nolint:containedctx @@ -73,6 +73,11 @@ type Server struct { connections uint32 running *atomic.Bool stopServer chan struct{} + + // loadbalancer + loadbalancerStrategy LoadBalancerStrategy + LoadbalancerStrategyName string + connectionToProxyMap map[*ConnWrapper]IProxy } var _ IServer = (*Server)(nil) @@ -149,10 +154,18 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) { } span.AddEvent("Ran the OnOpening hooks") + // Attempt to retrieve the next proxy. + proxy, err := s.loadbalancerStrategy.NextProxy() + if err != nil { + span.RecordError(err) + s.Logger.Error().Err(err).Msg("failed to retrieve next proxy") + return nil, Close + } + // Use the proxy to connect to the backend. Close the connection if the pool is exhausted. // This effectively get a connection from the pool and puts both the incoming and the server // connections in the pool of the busy connections. - if err := s.Proxy.Connect(conn); err != nil { + if err := proxy.Connect(conn); err != nil { if errors.Is(err, gerr.ErrPoolExhausted) { span.RecordError(err) return nil, Close @@ -165,6 +178,9 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) { return nil, None } + // Assign connection to proxy + s.connectionToProxyMap[conn] = proxy + // Run the OnOpened hooks. 
pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), s.PluginTimeout) defer cancel() @@ -225,15 +241,27 @@ func (s *Server) OnClose(conn *ConnWrapper, err error) Action { span.AddEvent("Shutting down the server") return Shutdown } + + // Find the proxy associated with the given connection + proxy, exists := s.GetProxyForConnection(conn) + if !exists { + // Log an error and return Close if no matching proxy is found + s.Logger.Error().Msg("Failed to find proxy to disconnect it") + return Close + } + // Disconnect the connection from the proxy. This effectively removes the mapping between // the incoming and the server connections in the pool of the busy connections and either // recycles or disconnects the connections. - if err := s.Proxy.Disconnect(conn); err != nil { + if err := proxy.Disconnect(conn); err != nil { s.Logger.Error().Err(err).Msg("Failed to disconnect the server connection") span.RecordError(err) return Close } + // remove a connection from proxy connention map + s.RemoveConnectionFromMap(conn) + if conn.IsTLSEnabled() { metrics.TLSConnections.Dec() } @@ -303,7 +331,16 @@ func (s *Server) OnTraffic(conn *ConnWrapper, stopConnection chan struct{}) Acti go func(server *Server, conn *ConnWrapper, stopConnection chan struct{}, stack *Stack) { for { server.Logger.Trace().Msg("Passing through traffic from client to server") - if err := server.Proxy.PassThroughToServer(conn, stack); err != nil { + + // Find the proxy associated with the given connection + proxy, exists := server.GetProxyForConnection(conn) + if !exists { + server.Logger.Error().Msg("Failed to find proxy that matches the connection") + stopConnection <- struct{}{} + break + } + + if err := proxy.PassThroughToServer(conn, stack); err != nil { server.Logger.Trace().Err(err).Msg("Failed to pass through traffic") span.RecordError(err) stopConnection <- struct{}{} @@ -317,7 +354,15 @@ func (s *Server) OnTraffic(conn *ConnWrapper, stopConnection chan struct{}) Acti go func(server *Server, conn *ConnWrapper, stopConnection chan struct{}, stack *Stack) { for { server.Logger.Trace().Msg("Passing through traffic from server to client") - if err := server.Proxy.PassThroughToClient(conn, stack); err != nil { + + // Find the proxy associated with the given connection + proxy, exists := server.GetProxyForConnection(conn) + if !exists { + server.Logger.Error().Msg("Failed to find proxy that matches the connection") + stopConnection <- struct{}{} + break + } + if err := proxy.PassThroughToClient(conn, stack); err != nil { server.Logger.Trace().Err(err).Msg("Failed to pass through traffic") span.RecordError(err) stopConnection <- struct{}{} @@ -352,8 +397,10 @@ func (s *Server) OnShutdown() { } span.AddEvent("Ran the OnShutdown hooks") - // Shutdown the proxy. - s.Proxy.Shutdown() + // Shutdown proxies. + for _, proxy := range s.Proxies { + proxy.Shutdown() + } // Set the server status to stopped. This is used to shutdown the server gracefully in OnClose. s.mu.Lock() @@ -573,8 +620,10 @@ func (s *Server) Shutdown() { _, span := otel.Tracer("gatewayd").Start(s.ctx, "Shutdown") defer span.End() - // Shutdown the proxy. - s.Proxy.Shutdown() + for _, proxy := range s.Proxies { + // Shutdown the proxy. + proxy.Shutdown() + } // Set the server status to stopped. This is used to shutdown the server gracefully in OnClose. s.mu.Lock() @@ -627,24 +676,26 @@ func NewServer( // Create the server. 
server := Server{ - ctx: serverCtx, - Network: srv.Network, - Address: srv.Address, - Options: srv.Options, - TickInterval: srv.TickInterval, - Status: config.Stopped, - EnableTLS: srv.EnableTLS, - CertFile: srv.CertFile, - KeyFile: srv.KeyFile, - HandshakeTimeout: srv.HandshakeTimeout, - Proxy: srv.Proxy, - Logger: srv.Logger, - PluginRegistry: srv.PluginRegistry, - PluginTimeout: srv.PluginTimeout, - mu: &sync.RWMutex{}, - connections: 0, - running: &atomic.Bool{}, - stopServer: make(chan struct{}), + ctx: serverCtx, + Network: srv.Network, + Address: srv.Address, + Options: srv.Options, + TickInterval: srv.TickInterval, + Status: config.Stopped, + EnableTLS: srv.EnableTLS, + CertFile: srv.CertFile, + KeyFile: srv.KeyFile, + HandshakeTimeout: srv.HandshakeTimeout, + Proxies: srv.Proxies, + Logger: srv.Logger, + PluginRegistry: srv.PluginRegistry, + PluginTimeout: srv.PluginTimeout, + mu: &sync.RWMutex{}, + connections: 0, + running: &atomic.Bool{}, + stopServer: make(chan struct{}), + connectionToProxyMap: make(map[*ConnWrapper]IProxy), + LoadbalancerStrategyName: srv.LoadbalancerStrategyName, } // Try to resolve the address and log an error if it can't be resolved. @@ -664,6 +715,12 @@ func NewServer( "GatewayD is listening on an unresolved address") } + st, err := NewLoadBalancerStrategy(&server) + if err != nil { + srv.Logger.Error().Err(err).Msg("Failed to create a loadbalancer strategy") + } + server.loadbalancerStrategy = st + return &server } @@ -673,3 +730,14 @@ func (s *Server) CountConnections() int { defer s.mu.RUnlock() return int(s.connections) } + +// GetProxyForConnection returns the proxy associated with the given connection. +func (s *Server) GetProxyForConnection(conn *ConnWrapper) (IProxy, bool) { + proxy, exists := s.connectionToProxyMap[conn] + return proxy, exists +} + +// RemoveConnectionFromMap removes the given connection from the connection-to-proxy map. +func (s *Server) RemoveConnectionFromMap(conn *ConnWrapper) { + delete(s.connectionToProxyMap, conn) +} diff --git a/network/server_test.go b/network/server_test.go index b090e126..c655afaa 100644 --- a/network/server_test.go +++ b/network/server_test.go @@ -114,11 +114,12 @@ func TestRunServer(t *testing.T) { Options: Option{ EnableTicker: true, }, - Proxy: proxy, - Logger: logger, - PluginRegistry: pluginRegistry, - PluginTimeout: config.DefaultPluginTimeout, - HandshakeTimeout: config.DefaultHandshakeTimeout, + Proxies: []IProxy{proxy}, + Logger: logger, + PluginRegistry: pluginRegistry, + PluginTimeout: config.DefaultPluginTimeout, + HandshakeTimeout: config.DefaultHandshakeTimeout, + LoadbalancerStrategyName: config.RoundRobinStrategy, }, ) assert.NotNil(t, server) diff --git a/plugin/utils.go b/plugin/utils.go index d117b5f3..bb4bc87d 100644 --- a/plugin/utils.go +++ b/plugin/utils.go @@ -42,6 +42,11 @@ func castToPrimitiveTypes(args map[string]interface{}) map[string]interface{} { } } args[key] = array + case map[string]map[string]any: + for _, valuemap := range value { + // Recursively cast nested maps. + args[key] = castToPrimitiveTypes(valuemap) + } // TODO: Add more types here as needed. 
default: args[key] = value diff --git a/testdata/gatewayd_tls.yaml b/testdata/gatewayd_tls.yaml index 604e7fc9..f5e24c74 100644 --- a/testdata/gatewayd_tls.yaml +++ b/testdata/gatewayd_tls.yaml @@ -12,15 +12,30 @@ metrics: clients: default: - address: localhost:5432 + activeWrites: + address: localhost:5432 + network: tcp + tcpKeepAlive: False + tcpKeepAlivePeriod: 30s + receiveChunkSize: 8192 + receiveDeadline: 0s + receiveTimeout: 0s + sendDeadline: 0s + dialTimeout: 60s + retries: 3 + backoff: 1s + backoffMultiplier: 2.0 + disableBackoffCaps: false pools: default: - size: 10 + activeWrites: + size: 10 proxies: default: - healthCheckPeriod: 60s # duration + activeWrites: + healthCheckPeriod: 60s # duration servers: default: