Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions aibtrace/aibtrace.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package aibtrace

import (
"context"

"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
)

type (
traceInterceptionAttrsContextKey struct{}
traceRequestBridgeAttrsContextKey struct{}
)

const (
// trace attribute key constants
RequestPath = "request_path"

InterceptionID = "interception_id"
InitiatorID = "user_id"
Provider = "provider"
Model = "model"
Streaming = "streaming"
IsBedrock = "aws_bedrock"

PassthroughURL = "passthrough_url"
PassthroughMethod = "passthrough_method"

MCPInput = "mcp_input"
MCPProxyName = "mcp_proxy_name"
MCPToolName = "mcp_tool_name"
MCPServerName = "mcp_server_name"
MCPServerURL = "mcp_server_url"
MCPToolCount = "mcp_tool_count"

APIKeyID = "api_key_id"
)

func WithInterceptionAttributesInContext(ctx context.Context, traceAttrs []attribute.KeyValue) context.Context {
return context.WithValue(ctx, traceInterceptionAttrsContextKey{}, traceAttrs)
}

func InterceptionAttributesFromContext(ctx context.Context) []attribute.KeyValue {
attrs, ok := ctx.Value(traceInterceptionAttrsContextKey{}).([]attribute.KeyValue)
if !ok {
return nil
}

return attrs
}

func WithRequestBridgeAttributesInContext(ctx context.Context, traceAttrs []attribute.KeyValue) context.Context {
return context.WithValue(ctx, traceRequestBridgeAttrsContextKey{}, traceAttrs)
}

func RequestBridgeAttributesFromContext(ctx context.Context) []attribute.KeyValue {
attrs, ok := ctx.Value(traceRequestBridgeAttrsContextKey{}).([]attribute.KeyValue)
if !ok {
return nil
}

return attrs
}

func EndSpanErr(span trace.Span, err *error) {
if span == nil {
return
}

if err != nil && *err != nil {
span.SetStatus(codes.Error, (*err).Error())
}
span.End()
}
7 changes: 4 additions & 3 deletions bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"cdr.dev/slog"
"github.com/coder/aibridge/mcp"
"go.opentelemetry.io/otel/trace"

"github.com/hashicorp/go-multierror"
)
Expand Down Expand Up @@ -47,20 +48,20 @@ var _ http.Handler = &RequestBridge{}
// A [Recorder] is also required to record prompt, tool, and token use.
//
// mcpProxy will be closed when the [RequestBridge] is closed.
func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, metrics *Metrics, logger slog.Logger) (*RequestBridge, error) {
func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, metrics *Metrics, tracer trace.Tracer, logger slog.Logger) (*RequestBridge, error) {
mux := http.NewServeMux()

for _, provider := range providers {
// Add the known provider-specific routes which are bridged (i.e. intercepted and augmented).
for _, path := range provider.BridgedRoutes() {
mux.HandleFunc(path, newInterceptionProcessor(provider, logger, recorder, mcpProxy, metrics))
mux.HandleFunc(path, newInterceptionProcessor(provider, logger, recorder, mcpProxy, metrics, tracer))
}

// Any requests which passthrough to this will be reverse-proxied to the upstream.
//
// We have to whitelist the known-safe routes because an API key with elevated privileges (i.e. admin) might be
// configured, so we should just reverse-proxy known-safe routes.
ftr := newPassthroughRouter(provider, logger.Named(fmt.Sprintf("passthrough.%s", provider.Name())), metrics)
ftr := newPassthroughRouter(provider, logger.Named(fmt.Sprintf("passthrough.%s", provider.Name())), metrics, tracer)
for _, path := range provider.PassthroughRoutes() {
prefix := fmt.Sprintf("/%s", provider.Name())
route := fmt.Sprintf("%s%s", prefix, path)
Expand Down
124 changes: 65 additions & 59 deletions bridge_integration_test.go

Large diffs are not rendered by default.

14 changes: 11 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ require (
github.com/openai/openai-go/v2 v2.7.0
)

require (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
require (
// Tracing-related libs.
require (

Also, thanks for getting rid of go-cmp 👍

github.com/google/go-cmp v0.7.0
go.opentelemetry.io/otel v1.38.0
go.opentelemetry.io/otel/sdk v1.38.0
go.opentelemetry.io/otel/trace v1.38.0
)

require (
github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 // indirect
Expand All @@ -46,6 +53,8 @@ require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/charmbracelet/lipgloss v0.7.1 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/hashicorp/errwrap v1.0.0 // indirect
github.com/invopop/jsonschema v0.13.0 // indirect
github.com/kylelemons/godebug v1.1.0 // indirect
Expand All @@ -61,14 +70,13 @@ require (
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/rivo/uniseg v0.4.4 // indirect
github.com/rogpeppe/go-internal v1.13.1 // indirect
github.com/spf13/cast v1.7.1 // indirect
github.com/tidwall/match v1.2.0 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
go.opentelemetry.io/otel v1.33.0 // indirect
go.opentelemetry.io/otel/trace v1.33.0 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/otel/metric v1.38.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/sys v0.35.0 // indirect
golang.org/x/term v0.34.0 // indirect
Expand Down
23 changes: 13 additions & 10 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
Expand Down Expand Up @@ -130,14 +131,16 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/otel v1.33.0 h1:/FerN9bax5LoK51X/sI0SVYrjSE0/yUL7DpxW4K3FWw=
go.opentelemetry.io/otel v1.33.0/go.mod h1:SUUkR6csvUQl+yjReHu5uM3EtVV7MBm5FHKRlNx4I8I=
go.opentelemetry.io/otel/metric v1.33.0 h1:r+JOocAyeRVXD8lZpjdQjzMadVZp2M4WmQ+5WtEnklQ=
go.opentelemetry.io/otel/metric v1.33.0/go.mod h1:L9+Fyctbp6HFTddIxClbQkjtubW6O9QS3Ann/M82u6M=
go.opentelemetry.io/otel/sdk v1.16.0 h1:Z1Ok1YsijYL0CSJpHt4cS3wDDh7p572grzNrBMiMWgE=
go.opentelemetry.io/otel/sdk v1.16.0/go.mod h1:tMsIuKXuuIWPBAOrH+eHtvhTL+SntFtXF9QD68aP6p4=
go.opentelemetry.io/otel/trace v1.33.0 h1:cCJuF7LRjUFso9LPnEAHJDB2pqzp+hbO8eu1qqW2d/s=
go.opentelemetry.io/otel/trace v1.33.0/go.mod h1:uIcdVUZMpTAmz0tI1z04GoVSezK37CbGV4fr1f2nBck=
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg=
go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM=
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
Expand Down
16 changes: 16 additions & 0 deletions intercept_anthropic_messages_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@ import (
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
aibtrace "github.com/coder/aibridge/aibtrace"
"github.com/coder/aibridge/mcp"
"github.com/google/uuid"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"

"cdr.dev/slog"
)
Expand All @@ -27,6 +30,7 @@ type AnthropicMessagesInterceptionBase struct {
cfg AnthropicConfig
bedrockCfg *AWSBedrockConfig

tracer trace.Tracer
logger slog.Logger

recorder Recorder
Expand Down Expand Up @@ -59,6 +63,18 @@ func (i *AnthropicMessagesInterceptionBase) Model() string {
return string(i.req.Model)
}

func (s *AnthropicMessagesInterceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
return []attribute.KeyValue{
attribute.String(aibtrace.RequestPath, r.URL.Path),
attribute.String(aibtrace.InterceptionID, s.id.String()),
attribute.String(aibtrace.InitiatorID, actorFromContext(r.Context()).id),
attribute.String(aibtrace.Provider, ProviderAnthropic),
attribute.String(aibtrace.Model, s.Model()),
attribute.Bool(aibtrace.Streaming, streaming),
attribute.Bool(aibtrace.IsBedrock, s.bedrockCfg != nil),
}
}

func (i *AnthropicMessagesInterceptionBase) injectTools() {
if i.req == nil || i.mcpProxy == nil {
return
Expand Down
32 changes: 25 additions & 7 deletions intercept_anthropic_messages_blocking.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package aibridge

import (
"context"
"encoding/json"
"fmt"
"net/http"
Expand All @@ -10,7 +11,10 @@ import (
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/google/uuid"
mcplib "github.com/mark3labs/mcp-go/mcp" // TODO: abstract this away so callers need no knowledge of underlying lib.
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"

aibtrace "github.com/coder/aibridge/aibtrace"
"github.com/coder/aibridge/mcp"

"cdr.dev/slog"
Expand All @@ -22,29 +26,35 @@ type AnthropicMessagesBlockingInterception struct {
AnthropicMessagesInterceptionBase
}

func NewAnthropicMessagesBlockingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig) *AnthropicMessagesBlockingInterception {
func NewAnthropicMessagesBlockingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig, tracer trace.Tracer) *AnthropicMessagesBlockingInterception {
return &AnthropicMessagesBlockingInterception{AnthropicMessagesInterceptionBase: AnthropicMessagesInterceptionBase{
id: id,
req: req,
cfg: cfg,
bedrockCfg: bedrockCfg,
tracer: tracer,
}}
}

func (s *AnthropicMessagesBlockingInterception) Setup(logger slog.Logger, recorder Recorder, mcpProxy mcp.ServerProxier) {
s.AnthropicMessagesInterceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy)
func (i *AnthropicMessagesBlockingInterception) Setup(logger slog.Logger, recorder Recorder, mcpProxy mcp.ServerProxier) {
i.AnthropicMessagesInterceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy)
}

func (i *AnthropicMessagesBlockingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
return i.AnthropicMessagesInterceptionBase.baseTraceAttributes(r, false)
}

func (s *AnthropicMessagesBlockingInterception) Streaming() bool {
return false
}

func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) error {
func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) {
if i.req == nil {
return fmt.Errorf("developer error: req is nil")
}

ctx := r.Context()
ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(r.Context())...))
defer aibtrace.EndSpanErr(span, &outErr)

i.injectTools()

Expand Down Expand Up @@ -77,7 +87,8 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr
var cumulativeUsage anthropic.Usage

for {
resp, err = svc.New(ctx, messages)
// TODO add outer loop span (https://github.com/coder/aibridge/issues/67)
resp, err = i.traceNewMessage(ctx, svc, messages) // traces svc.New(ctx, msgParams) call
if err != nil {
if isConnError(err) {
// Can't write a response, just error out.
Expand Down Expand Up @@ -166,7 +177,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr
continue
}

res, err := tool.Call(ctx, tc.Input)
res, err := tool.Call(ctx, i.tracer, tc.Input)

_ = i.recorder.RecordToolUsage(ctx, &ToolUsageRecord{
InterceptionID: i.ID().String(),
Expand Down Expand Up @@ -285,3 +296,10 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr

return nil
}

func (i *AnthropicMessagesBlockingInterception) traceNewMessage(ctx context.Context, svc anthropic.MessageService, msgParams anthropic.MessageNewParams) (_ *anthropic.Message, outErr error) {
ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(aibtrace.InterceptionAttributesFromContext(ctx)...))
defer aibtrace.EndSpanErr(span, &outErr)

return svc.New(ctx, msgParams)
}
Loading