Skip to content

Commit

Permalink
Prototype stream authentication
Browse files Browse the repository at this point in the history
Signed-off-by: Andrea Mazzotti <[email protected]>
  • Loading branch information
anmazzotti committed Nov 21, 2024
1 parent b3cc38f commit dfc94b5
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 28 deletions.
162 changes: 162 additions & 0 deletions internal/proto/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package proto

import (
"context"
"errors"
"fmt"
"strings"

"github.com/go-logr/logr"
"github.com/golang-jwt/jwt/v5"
"github.com/rancher-sandbox/cluster-api-provider-elemental/api/v1beta1"
"github.com/rancher-sandbox/cluster-api-provider-elemental/internal/log"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"k8s.io/apimachinery/pkg/types"

"sigs.k8s.io/controller-runtime/pkg/client"
)

const (
validatedHostKey = "ValidatedElementalHost"
)

var (
ErrPermissionDenied = errors.New("Permission Denied")
)

type contextKey string

type Authenticator interface {
CanGetElementalRegistration(registrationToken string, registrationPubKey []byte) (bool, error)
CanCreateElementalHost(registrationToken, hostToken string, registrationPubKey, hostPubKey []byte) (bool, error)
CanGetPatchDeleteElementalHost(hostToken string, hostPubKey []byte) (bool, error)
}

func NewAuthenticator() Authenticator {
return &authenticator{}
}

type authenticator struct {
}

func (a *authenticator) CanGetElementalRegistration(registrationToken string, registrationPubKey []byte) (bool, error) {
return true, nil
}
func (a *authenticator) CanCreateElementalHost(registrationToken, hostToken string, registrationPubKey, hostPubKey []byte) (bool, error) {
return true, nil
}
func (a *authenticator) CanGetPatchDeleteElementalHost(hostToken string, hostPubKey []byte) (bool, error) {
return true, nil
}

type AuthInterceptor interface {
}

func NewAuthInterceptor(k8sClient client.Client, logger logr.Logger) AuthInterceptor {
return &authInterceptor{
k8sClient: k8sClient,
logger: logger,
authenticator: NewAuthenticator(),
}
}

type authInterceptor struct {
k8sClient client.Client
logger logr.Logger
authenticator Authenticator
}

type wrappedHostStream struct {
grpc.ServerStream
ctx context.Context
}

func (s wrappedHostStream) Context() context.Context {
return s.ctx
}

func (a *authInterceptor) InterceptStream(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
switch info.FullMethod {
case "elemental/ReconcileHost":
md, ok := metadata.FromIncomingContext(ss.Context())
if !ok {
return status.Errorf(codes.InvalidArgument, "Missing metadata")
}
authorization, found := md["authorization"]
if !found || len(authorization) < 1 {
return status.Errorf(codes.Unauthenticated, "Missing authorization header")
}
host, err := a.ValidateHostRequest(ss.Context(), authorization[0])
if err != nil {
return status.Errorf(codes.PermissionDenied, "Permission Denied: %s", err.Error())
}
// Inject the validated ElementalHost in the wrapped stream context
wrappedStream := wrappedHostStream{ServerStream: ss, ctx: context.WithValue(ss.Context(), contextKey(validatedHostKey), host)}
return handler(srv, wrappedStream)
default:
a.logger.Info("Dropping unexpected call", "method", info.FullMethod)
return status.Errorf(codes.Unimplemented, "Could not authenticate method: %s", info.FullMethod)
}
}

func (a *authInterceptor) InterceptUnary(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
return nil, status.Errorf(codes.Unimplemented, "Uniplemented intecteptor")
}

func (a *authInterceptor) ValidateHostRequest(ctx context.Context, bearerToken string) (v1beta1.ElementalHost, error) {
host := v1beta1.ElementalHost{}

token, found := strings.CutPrefix(bearerToken, "Bearer ")
if !found {
return host, fmt.Errorf("not a 'Bearer' token")
}

_, err := jwt.Parse(token, func(parsedToken *jwt.Token) (any, error) {
// Extract the subject from parsed token
subject, err := parsedToken.Claims.GetSubject()
if err != nil {
return nil, fmt.Errorf("getting subject from token: %w", err)
}

subjectParts := strings.Split(subject, string(types.Separator))
if len(subjectParts) < 2 {
return nil, fmt.Errorf("parsing subject '%s': Bad format", subject)
}

// Fetch the ElementalHost from subject
hostKey := client.ObjectKey{
Namespace: subjectParts[0],
Name: subjectParts[1],
}

logger := a.logger.WithValues(log.KeyNamespace, hostKey.Namespace).
WithValues(log.KeyElementalHost, hostKey.Name)

if err := a.k8sClient.Get(ctx, hostKey, &host); err != nil {
logger.Error(err, "Could not get ElementalHost")
return nil, ErrPermissionDenied
}

// Verify signature using ElementalHost's PubKey
signingAlg := parsedToken.Method.Alg()
switch signingAlg {
case "EdDSA":
pubKey, err := jwt.ParseEdPublicKeyFromPEM([]byte(host.Spec.PubKey))
if err != nil {
logger.Error(err, "Could not parse ElementalHost.spec.PubKey")
return nil, ErrPermissionDenied
}
return pubKey, nil
default:
logger.Error(err, "JWT is using unsupported signing algorithm", "JWT Signing Alg", signingAlg)
return nil, ErrPermissionDenied
}
})
if err != nil {
return host, fmt.Errorf("validating JWT token: %w", err)
}
return host, nil
}
83 changes: 55 additions & 28 deletions internal/proto/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/rancher-sandbox/cluster-api-provider-elemental/internal/log"
pb "github.com/rancher-sandbox/cluster-api-provider-elemental/pkg/api/proto/v1"
corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1"
"sigs.k8s.io/controller-runtime/pkg/client"

Expand Down Expand Up @@ -125,34 +126,47 @@ func (s *server) GetBootstrap(context.Context, *pb.BootstrapRequest) (*pb.Bootst
return nil, status.Errorf(codes.Unimplemented, "method GetBootstrap not implemented")
}
func (s *server) ReconcileHost(stream grpc.BidiStreamingServer[pb.HostPatchRequest, pb.HostResponse]) error {
incoming, err := stream.Recv()
if errors.Is(err, io.EOF) {
s.logger.Info("Stream closed before any message was received")
return nil
// Fetch host from authenticated wrapped stream
validatedHost := stream.Context().Value(contextKey(validatedHostKey))
if validatedHost == nil {
s.logger.Info("Closing stream due to missing validated ElementalHost")
return status.Errorf(codes.Internal, "Missing validated ElementalHost")
}
if err != nil {
return fmt.Errorf("reading first time stream: %w", err)

host, ok := validatedHost.(v1beta1.ElementalHost)
if !ok {
s.logger.Info("Closing stream due to validated ElementalHost being incorrect type")
return status.Errorf(codes.Internal, "Validated ElementalHost is incorrect type")
}
// !! Validation/Auth can happen here !!
logger := s.logger.WithValues(log.KeyNamespace, incoming.Namespace).
WithValues(log.KeyElementalHost, incoming.Name)

// Since we no longer validate messages after the stream is open, it's important to set a static host key.
// This is to prevent the host from assuming different identities (patching other hosts) after authentication.
hostKey := client.ObjectKey{Namespace: incoming.Namespace, Name: incoming.Name}
logger := s.logger.WithValues(log.KeyNamespace, host.Namespace).
WithValues(log.KeyElementalHost, host.Name)

// Always send back a first response.
// This gives the consumer a chance to reconcile from previously unreceived messages,
// even if the ElementalHost has not mutated meanwhile.
if err := s.sendElementalHostToStream(hostKey, stream); err != nil {
return fmt.Errorf("sending first ElementalHost to stream: %w", err)
response, err := getElementalHostResponse(host)
if err != nil {
logger.Error(err, "Could not format HostResponse")
return status.Errorf(codes.Internal, "getting HostResponse: %s", err.Error())
}
if err := stream.Send(response); err != nil {
logger.Error(err, "Could not send HostResponse")
return status.Errorf(codes.DataLoss, "sending HostResponse: %s", err.Error())
}

// Since we no longer validate messages after the stream is open, it's important to set a static host key.
// This is to prevent the host from assuming different identities (patching other hosts) after authentication.
hostKey := client.ObjectKey{Namespace: host.Namespace, Name: host.Name}

// Asynchronously patch ElementalHost resource from stream input
// Note: stream.Recv() can be consumed concurrently to stream.Send()
readingErrors := make(chan error, 1)
var readingErrorCode codes.Code
go func() {
if err := s.updateElementalHostFromStream(logger, hostKey, stream); err != nil {
logger.Error(err, "Failed to consume stream")
if code, err := s.updateElementalHostFromStream(logger, hostKey, stream); err != nil {
readingErrorCode = code
readingErrors <- err
return
}
}()
Expand All @@ -165,47 +179,57 @@ func (s *server) ReconcileHost(stream grpc.BidiStreamingServer[pb.HostPatchReque
select {
case <-s.hosts[hostKey.String()]:
logger.Info("Sending update")
if err := s.sendElementalHostToStream(hostKey, stream); err != nil {
return fmt.Errorf("sending ElementalHost to stream: %w", err)
if code, err := s.sendElementalHostToStream(hostKey, stream); err != nil {
return status.Errorf(code, "sending ElementalHost to stream: %s", err.Error())
}
case <-stream.Context().Done():
// Stream is closed
return nil
case err := <-readingErrors:
// If we can no longer consume the stream, close it
logger.Error(err, "Failed to consume ElementalHost stream")
return status.Errorf(readingErrorCode, "consuming ElementalHost stream: %s", err.Error())
}
}
}

func (s *server) sendElementalHostToStream(key types.NamespacedName, stream grpc.BidiStreamingServer[pb.HostPatchRequest, pb.HostResponse]) error {
func (s *server) sendElementalHostToStream(key types.NamespacedName, stream grpc.BidiStreamingServer[pb.HostPatchRequest, pb.HostResponse]) (codes.Code, error) {
elementalHost := &v1beta1.ElementalHost{}
if err := s.k8sClient.Get(stream.Context(), key, elementalHost); err != nil {
return fmt.Errorf("getting ElementalHost: %w", err)
if apierrors.IsNotFound(err) {
return codes.NotFound, fmt.Errorf("ElementalHost '%s' not found", key)
}
return codes.Internal, fmt.Errorf("getting ElementalHost: %w", err)
}

response, err := getElementalHostResponse(*elementalHost)
if err != nil {
return fmt.Errorf("getting HostResponse: %w", err)
return codes.Internal, fmt.Errorf("getting HostResponse: %w", err)
}

if err := stream.Send(response); err != nil {
return fmt.Errorf("sending HostResponse: %w", err)
return codes.DataLoss, fmt.Errorf("sending HostResponse: %w", err)
}
return nil
return codes.OK, nil
}

func (s *server) updateElementalHostFromStream(logger logr.Logger, key types.NamespacedName, stream grpc.BidiStreamingServer[pb.HostPatchRequest, pb.HostResponse]) error {
func (s *server) updateElementalHostFromStream(logger logr.Logger, key types.NamespacedName, stream grpc.BidiStreamingServer[pb.HostPatchRequest, pb.HostResponse]) (codes.Code, error) {
for {
incoming, err := stream.Recv()
if errors.Is(err, io.EOF) {
logger.Info("Stream closed")
return nil
return codes.OK, nil
}
if err != nil {
return fmt.Errorf("reading stream: %w", err)
return codes.DataLoss, fmt.Errorf("reading stream: %w", err)
}
err = retry.RetryOnConflict(retry.DefaultRetry, func() error {
err = retry.RetryOnConflict(retry.DefaultBackoff, func() error {
// Always refresh resource on each attempt
elementalHost := &v1beta1.ElementalHost{}
if err := s.k8sClient.Get(stream.Context(), key, elementalHost); err != nil {
if apierrors.IsNotFound(err) {
return fmt.Errorf("ElementalHost '%s' not found", key.String())
}
return fmt.Errorf("getting ElementalHost: %w", err)
}

Expand All @@ -220,7 +244,10 @@ func (s *server) updateElementalHostFromStream(logger logr.Logger, key types.Nam
return patchHelper.Patch(stream.Context(), elementalHost)
})
if err != nil {
return fmt.Errorf("patching ElementalHost: %w", err)
if apierrors.IsNotFound(err) {
return codes.NotFound, fmt.Errorf("ElementalHost '%s' not found", key.String())
}
return codes.Internal, fmt.Errorf("patching ElementalHost: %w", err)
}
}
}
Expand Down

0 comments on commit dfc94b5

Please sign in to comment.