Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 124 additions & 9 deletions pkg/billing/workflow_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,22 @@ package billing

import (
"context"
"crypto/sha256"
"crypto/x509"
"encoding/hex"
"encoding/json"
"fmt"

"github.com/golang-jwt/jwt/v5"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/emptypb"

auth "github.com/smartcontractkit/chainlink-common/pkg/nodeauth/jwt"
"github.com/smartcontractkit/chainlink-common/pkg/nodeauth/types"
pb "github.com/smartcontractkit/chainlink-protos/billing/go"

"github.com/smartcontractkit/chainlink-common/pkg/logger"
Expand Down Expand Up @@ -126,24 +132,96 @@ func (wc *workflowClient) Close() error {
}

// addJWTAuth creates and signs a JWT token, then adds it to the context
func (wc *workflowClient) addJWTAuth(ctx context.Context, req any) (context.Context, error) {
// Returns the updated context and the JWT token string (for logging purposes)
func (wc *workflowClient) addJWTAuth(ctx context.Context, req any) (context.Context, string, error) {
// Skip authentication if no JWT manager provided
if wc.jwtGenerator == nil {
return ctx, nil
return ctx, "", nil
}

// Create JWT token using the JWT manager
jwtToken, err := wc.jwtGenerator.CreateJWTForRequest(req)
if err != nil {
return nil, fmt.Errorf("failed to create JWT: %w", err)
return nil, "", fmt.Errorf("failed to create JWT: %w", err)
}

// Add JWT to Authorization header
return metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+jwtToken), nil
return metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+jwtToken), jwtToken, nil
}

// parseJWTForLogging parses the JWT token without verification to extract claims for logging purposes
func parseJWTForLogging(tokenString string) *types.NodeJWTClaims {
if tokenString == "" {
return nil
}
parser := jwt.NewParser()
token, _, err := parser.ParseUnverified(tokenString, &types.NodeJWTClaims{})
if err != nil {
return nil
}

claims, ok := token.Claims.(*types.NodeJWTClaims)
if !ok {
return nil
}

return claims
}

// DigestDebugInfo contains debugging information about digest calculation
type DigestDebugInfo struct {
RequestType string
IsProtoMessage bool
SerializedLength int
RequestJSON string
MarshalSuccess bool
}

// calculateDigestWithDebugging calculates the request digest with detailed debugging information
func calculateDigestWithDebugging(req any) (string, DigestDebugInfo) {
debugInfo := DigestDebugInfo{
RequestType: fmt.Sprintf("%T", req),
}

var data []byte
if m, ok := req.(proto.Message); ok {
debugInfo.IsProtoMessage = true
// Use protobuf canonical serialization
serialized, err := proto.Marshal(m)
if err == nil {
debugInfo.MarshalSuccess = true
data = serialized
} else {
debugInfo.MarshalSuccess = false
// fallback to string representation if marshal fails
data = fmt.Appendf(nil, "%v", req)
}
} else if s, ok := req.(fmt.Stringer); ok {
debugInfo.IsProtoMessage = false
debugInfo.MarshalSuccess = true
data = []byte(s.String())
} else {
debugInfo.IsProtoMessage = false
debugInfo.MarshalSuccess = true
data = fmt.Appendf(nil, "%v", req)
}

debugInfo.SerializedLength = len(data)

// Create JSON representation of the request for human-readable debugging
if jsonBytes, err := json.Marshal(req); err == nil {
debugInfo.RequestJSON = string(jsonBytes)
} else {
// Fallback to string representation
debugInfo.RequestJSON = fmt.Sprintf("%+v", req)
}

hash := sha256.Sum256(data)
return hex.EncodeToString(hash[:]), debugInfo
}

func (wc *workflowClient) GetOrganizationCreditsByWorkflow(ctx context.Context, req *pb.GetOrganizationCreditsByWorkflowRequest) (*pb.GetOrganizationCreditsByWorkflowResponse, error) {
ctx, err := wc.addJWTAuth(ctx, req)
ctx, _, err := wc.addJWTAuth(ctx, req)
if err != nil {
wc.logger.Errorw("Failed to add custom auth header to GetOrganizationCreditsByWorkflow request", "error", err)
return nil, err
Expand All @@ -157,7 +235,7 @@ func (wc *workflowClient) GetOrganizationCreditsByWorkflow(ctx context.Context,
}

func (wc *workflowClient) GetWorkflowExecutionRates(ctx context.Context, req *pb.GetWorkflowExecutionRatesRequest) (*pb.GetWorkflowExecutionRatesResponse, error) {
ctx, err := wc.addJWTAuth(ctx, req)
ctx, _, err := wc.addJWTAuth(ctx, req)
if err != nil {
wc.logger.Errorw("Failed to add custom auth header to GetWorkflowExecutionRates request", "error", err)
return nil, err
Expand All @@ -171,7 +249,7 @@ func (wc *workflowClient) GetWorkflowExecutionRates(ctx context.Context, req *pb
}

func (wc *workflowClient) ReserveCredits(ctx context.Context, req *pb.ReserveCreditsRequest) (*pb.ReserveCreditsResponse, error) {
ctx, err := wc.addJWTAuth(ctx, req)
ctx, _, err := wc.addJWTAuth(ctx, req)
if err != nil {
wc.logger.Errorw("Failed to add JWT auth to ReserveCredits request", "error", err)
return nil, err
Expand All @@ -185,14 +263,51 @@ func (wc *workflowClient) ReserveCredits(ctx context.Context, req *pb.ReserveCre
}

func (wc *workflowClient) SubmitWorkflowReceipt(ctx context.Context, req *pb.SubmitWorkflowReceiptRequest) (*emptypb.Empty, error) {
ctx, err := wc.addJWTAuth(ctx, req)
// Calculate digest and get debug info before adding JWT
clientDigest, debugInfo := calculateDigestWithDebugging(req)

// Add JWT authentication
ctx, jwtToken, err := wc.addJWTAuth(ctx, req)
if err != nil {
wc.logger.Errorw("Failed to add JWT auth to SubmitWorkflowReceipt request", "error", err)
return nil, err
}

// Parse JWT claims for logging
parsedClaims := parseJWTForLogging(jwtToken)

// Log detailed request information (matching billing service format)
logFields := []any{
"method", "SubmitWorkflowReceipt",
"jwt_token", jwtToken,
"client_calculated_digest", clientDigest,
"request_type", debugInfo.RequestType,
"is_proto_message", debugInfo.IsProtoMessage,
"serialized_length", debugInfo.SerializedLength,
"request_json", debugInfo.RequestJSON,
"marshal_success", debugInfo.MarshalSuccess,
}

if parsedClaims != nil {
logFields = append(logFields,
"parsed_public_key", parsedClaims.PublicKey,
"parsed_digest_from_jwt", parsedClaims.Digest,
"digest_match", parsedClaims.Digest == clientDigest,
"parsed_issuer", parsedClaims.Issuer,
"parsed_subject", parsedClaims.Subject,
"parsed_expires_at", parsedClaims.ExpiresAt,
"parsed_issued_at", parsedClaims.IssuedAt,
"parsed_audience", parsedClaims.Audience,
)
}

wc.logger.Infow("Sending SubmitWorkflowReceipt request", logFields...)

// Make the actual RPC call
resp, err := wc.client.SubmitWorkflowReceipt(ctx, req)
if err != nil {
wc.logger.Errorw("SubmitWorkflowReceipt failed", "error", err)
// Log error with the same detailed information for debugging
wc.logger.Errorw("SubmitWorkflowReceipt failed", append(logFields, "error", err)...)
return nil, err
}
return resp, nil
Expand Down
69 changes: 64 additions & 5 deletions pkg/billing/workflow_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,9 @@ func TestWorkflowClient_AddJWTAuthToContext(t *testing.T) {
}

ctx := context.Background()
newCtx, err := wc.addJWTAuth(ctx, req)
newCtx, jwtToken, err := wc.addJWTAuth(ctx, req)
require.NoError(t, err)
require.Equal(t, expectedToken, jwtToken, "Expected JWT token to be returned")

// Verify JWT is added to metadata
md, ok := metadata.FromOutgoingContext(newCtx)
Expand All @@ -260,8 +261,9 @@ func TestWorkflowClient_NoSigningKey(t *testing.T) {
logger: logger.Test(t),
jwtGenerator: nil,
}
newCtx, err := wc.addJWTAuth(ctx, req)
newCtx, jwtToken, err := wc.addJWTAuth(ctx, req)
require.NoError(t, err)
require.Empty(t, jwtToken, "Expected empty JWT token when no JWT generator is provided")

// Should return the same context
assert.Equal(t, ctx, newCtx)
Expand All @@ -280,8 +282,9 @@ func TestWorkflowClient_VerifySignature_Invalid(t *testing.T) {
}

ctx := context.Background()
_, err := wc.addJWTAuth(ctx, req)
_, jwtToken, err := wc.addJWTAuth(ctx, req)
require.Error(t, err)
require.Empty(t, jwtToken, "Expected empty JWT token on error")
assert.Contains(t, err.Error(), "failed to create JWT")
}

Expand All @@ -299,12 +302,14 @@ func TestWorkflowClient_RepeatedSign(t *testing.T) {
}

ctx1 := context.Background()
newCtx1, err := wc.addJWTAuth(ctx1, req)
newCtx1, jwtToken1, err := wc.addJWTAuth(ctx1, req)
require.NoError(t, err)
require.Equal(t, expectedToken, jwtToken1, "Expected JWT token to match")

ctx2 := context.Background()
newCtx2, err := wc.addJWTAuth(ctx2, req)
newCtx2, jwtToken2, err := wc.addJWTAuth(ctx2, req)
require.NoError(t, err)
require.Equal(t, expectedToken, jwtToken2, "Expected JWT token to match")

// Both should have the same token since we're mocking the same response
md1, ok := metadata.FromOutgoingContext(newCtx1)
Expand All @@ -314,3 +319,57 @@ func TestWorkflowClient_RepeatedSign(t *testing.T) {

assert.Equal(t, md1["authorization"], md2["authorization"], "Expected same authorization header for same request")
}

func TestWorkflowClient_SubmitWorkflowReceipt_WithLogging(t *testing.T) {
// Start a test gRPC server
lis, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
grpcServer := grpc.NewServer()
testSrv := &testWorkflowServer{}
pb.RegisterCreditReservationServiceServer(grpcServer, testSrv)
go func() {
_ = grpcServer.Serve(lis)
}()
defer grpcServer.Stop()

addr := lis.Addr().String()

// Create mock JWT manager for testing
mockJWT := mocks.NewJWTGenerator(t)
expectedToken := "test.jwt.token.for.logging"

// Create a test request
req := &pb.SubmitWorkflowReceiptRequest{
WorkflowOwner: "test-owner",
WorkflowId: "test-workflow-id",
WorkflowExecutionId: "test-execution-id",
WorkflowRegistryAddress: "0x123",
WorkflowRegistryChainSelector: 1,
CreditsConsumed: "100",
}

// Expect JWT creation
mockJWT.EXPECT().CreateJWTForRequest(req).Return(expectedToken, nil).Once()

lggr := logger.Test(t)
wc, err := NewWorkflowClient(lggr, addr,
WithWorkflowTransportCredentials(insecure.NewCredentials()),
WithJWTGenerator(mockJWT),
WithServerName("localhost"),
)
require.NoError(t, err)
defer func(wc WorkflowClient) {
_ = wc.Close()
}(wc)

// Call SubmitWorkflowReceipt
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
resp, err := wc.SubmitWorkflowReceipt(ctx, req)
require.NoError(t, err)
require.NotNil(t, resp)

// Note: In a real test, we would inspect the logs to verify the detailed
// logging is happening. For now, we're just ensuring the method works
// with the new logging code.
}
Loading