diff --git a/.gitignore b/.gitignore index cfaad76..1b17cb7 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ *.pem +app-gateway-go diff --git a/main.go b/main.go index 76f8226..63be15f 100644 --- a/main.go +++ b/main.go @@ -61,6 +61,7 @@ const ( gatewayVerboseEnvironmentVariable = "VERBOSE" logSecretsEnvironmentVariable = "LOG_SECRETS" targetRewritesVariables = "TARGET_REWRITES" + prometheusConfigVariable = "PROMETHEUS_CONFIG" ) var versionFlag = flag.Bool("version", false, "print name and version to stdout") @@ -198,8 +199,6 @@ func main() { debugResponse := getBoolEnv(gatewayDebugEnvironmentVariable, false) verbose := getBoolEnv(gatewayVerboseEnvironmentVariable, false) - monitoringServiceName := getStringEnv(monitoringServiceNameEnvironmentVariable, defaultMonitoringServiceName) - configID := uint8(getUintEnv(configurationIdEnvironmentVariable, 0)) config, err := ohttp.NewConfigFromSeed(configID, hpke.KEM_X25519_KYBER768_DRAFT00, hpke.KDF_HKDF_SHA256, hpke.AEAD_AES128GCM, seed) if err != nil { @@ -263,23 +262,40 @@ func main() { } // Configure metrics - metricsHost := os.Getenv(statsdHostVariable) - metricsPort := os.Getenv(statsdPortVariable) - metricsTimeout, err := strconv.ParseInt(os.Getenv(statsdTimeoutVariable), 10, 64) - if err != nil { - log.Printf("Failed parsing metrics timeout: %s", err) - metricsTimeout = 100 - } - client, err := createStatsDClient(metricsHost, metricsPort, int(metricsTimeout)) - if err != nil { - log.Fatalf("Failed to create statsd client: %s", err) - } - defer client.Close() + var metricsFactory MetricsFactory + + if prometheusConfigJSON := os.Getenv(prometheusConfigVariable); prometheusConfigJSON != "" { + log.Printf("prometheus config: %s", prometheusConfigJSON) + var prometheusConfig PrometheusConfig + if err := json.Unmarshal([]byte(prometheusConfigJSON), &prometheusConfig); err != nil { + log.Fatalf("Failed to parse Prometheus config: %s", err) + } + + metricsFactory, err = NewPrometheusMetricsFactory(prometheusConfig) + if err != nil { + log.Fatalf("Failed to configure Prometheus metrics: %s", err) + } + } else { + // Default to StatsD metrics + monitoringServiceName := getStringEnv(monitoringServiceNameEnvironmentVariable, defaultMonitoringServiceName) + metricsHost := os.Getenv(statsdHostVariable) + metricsPort := os.Getenv(statsdPortVariable) + metricsTimeout, err := strconv.ParseInt(getStringEnv(statsdTimeoutVariable, "100"), 10, 64) + if err != nil { + log.Fatalf("Failed parsing metrics timeout: %s", err) + metricsTimeout = 100 + } + client, err := createStatsDClient(metricsHost, metricsPort, int(metricsTimeout)) + if err != nil { + log.Fatalf("Failed to create statsd client: %s", err) + } + defer client.Close() - metricsFactory := &StatsDMetricsFactory{ - serviceName: monitoringServiceName, - metricsName: "ohttp_gateway_duration", - client: client, + metricsFactory = &StatsDMetricsFactory{ + serviceName: monitoringServiceName, + metricsName: "ohttp_gateway_duration", + client: client, + } } // Load endpoint configuration defaults diff --git a/prometheus_metrics.go b/prometheus_metrics.go new file mode 100644 index 0000000..3d103e3 --- /dev/null +++ b/prometheus_metrics.go @@ -0,0 +1,88 @@ +// Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "errors" + "fmt" + "log" + "net/http" + "time" + + "github.com/prometheus/client_golang/prometheus" + //"github.com/prometheus/client_golang/prometheus/promauto" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +type PrometheusConfig struct { + Host string + Port string + ScrapePath string + MetricName string +} + +type PrometheusMetrics struct { + startedAt time.Time + histogram prometheus.ObserverVec +} + +func (p *PrometheusMetrics) Fire(result string) { + observer := p.histogram.With(prometheus.Labels{"method": "unknown", "status": "unknown", "result": result}) + p.observe(observer) +} + +func (p *PrometheusMetrics) ResponseStatus(method string, status int) { + observer := p.histogram.With(prometheus.Labels{"method": method, "status": fmt.Sprint(status), "result": "unknown"}) + p.observe(observer) +} + +func (p *PrometheusMetrics) observe(observer prometheus.Observer) { + elapsed := time.Now().Sub(p.startedAt) + observer.Observe(float64(elapsed.Milliseconds())) +} + +type PrometheusMetricsFactory struct { + metricName string +} + +func NewPrometheusMetricsFactory(config PrometheusConfig) (MetricsFactory, error) { + log.Printf("prometheus config: %+v\n", config) + + serveMux := http.NewServeMux() + serveMux.Handle(config.ScrapePath, promhttp.Handler()) + server := http.Server{ + Addr: config.Host + ":" + config.Port, + Handler: serveMux, + } + + go func() { + log.Printf("Listening for Prometheus scrapes on %s:%s\n", config.Host, config.Port) + log.Fatal(server.ListenAndServe()) + }() + + return &PrometheusMetricsFactory{metricName: config.MetricName}, nil +} + +func (p PrometheusMetricsFactory) Create(eventName string) Metrics { + histogram := prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Name: p.metricName, + }, []string{"eventName", "status", "method", "result"}) + + if err := prometheus.Register(histogram); err != nil { + are := &prometheus.AlreadyRegisteredError{} + if errors.As(err, are) { + // Use previously registered metric collector + histogram = are.ExistingCollector.(*prometheus.HistogramVec) + } else { + // There's no other reason prometheus.Register should fail and the interface won't let + // us return an error. + panic(err) + } + } + + return &PrometheusMetrics{ + startedAt: time.Now(), + histogram: histogram.MustCurryWith(prometheus.Labels{"eventName": eventName}), + } +}