From 19fc482a5892e5414cc727e6c71d6aa0a3e07d1e Mon Sep 17 00:00:00 2001 From: Jessie Liu Date: Thu, 16 Nov 2023 11:24:27 -0800 Subject: [PATCH] Check for enough unwrapped shares before combining PiperOrigin-RevId: 583109015 Change-Id: Ib84b11c072c95b983c3bcc0eb73c165e777a1968 --- client/client.go | 51 +++++++++++++++++++++++++++++++++---------- client/client_test.go | 51 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 11 deletions(-) diff --git a/client/client.go b/client/client.go index 98d751b..f035365 100644 --- a/client/client.go +++ b/client/client.go @@ -325,21 +325,21 @@ func (c *StetClient) unwrapAndValidateShares(ctx context.Context, wrappedShares // implementation handle the subset of shares. var unwrappedShares []shares.UnwrappedShare for i, wrapped := range wrappedShares { - glog.Infof("Attempting to unwrap share #%v", i+1) unwrapped := shares.UnwrappedShare{} kek := opts.kekInfos[i] + glog.Infof("Attempting to unwrap share #%v, URI %v", i+1, kek.GetKekUri()) switch x := kek.KekType.(type) { case *configpb.KekInfo_RsaFingerprint: key, err := PrivateKeyForRSAFingerprint(kek, opts.asymmetricKeys) if err != nil { - glog.Warningf("Failed to find public key for RSA fingerprint: %v", err) + glog.Errorf("Failed to find private key for RSA fingerprint: %v", err) continue } unwrapped.Share, err = rsa.DecryptOAEP(sha256.New(), rand.Reader, key, wrapped.GetShare(), nil) if err != nil { - glog.Warningf("Error unwrapping key share: %v", err) + glog.Errorf("Error unwrapping key share for %v: %v", kek.GetKekUri(), err) continue } @@ -352,12 +352,14 @@ func (c *StetClient) unwrapAndValidateShares(ctx context.Context, wrappedShares kmsClient, err := kmsClients.Client(ctx, creds) if err != nil { - return nil, fmt.Errorf("error initializing Cloud KMS Client with credentials \"%v\": %v", creds, err) + glog.Errorf("Error initializing Cloud KMS Client with credentials \"%v\" for %v: %v", creds, kek.GetKekUri(), err) + continue } kmd, err := getKekURIMetadata(ctx, kmsClient, kek) if err != nil { - return nil, fmt.Errorf("Error retrieving KEK Metadata: %v", err) + glog.Errorf("Error retrieving KEK Metadata for %v: %v", kek.GetKekUri(), err) + continue } // Unwrap share via KMS. @@ -369,17 +371,17 @@ func (c *StetClient) unwrapAndValidateShares(ctx context.Context, wrappedShares } unwrapped.Share, err = cloudkms.UnwrapShare(ctx, kmsClient, unwrapOpts) if err != nil { - glog.Warningf("Error unwrapping key share: %v", err) + glog.Errorf("Error unwrapping key sharefor %v: %v", kmd.uri, err) continue } case rpb.ProtectionLevel_EXTERNAL: unwrapped.Share, err = c.ekmSecureSessionUnwrap(ctx, wrapped.GetShare(), *kmd) if err != nil { - glog.Warningf("Error unwrapping with external EKM for %v: %v", kmd.uri, err) + glog.Errorf("Error unwrapping with external EKM for %v: %v", kmd.uri, err) continue } default: - glog.Warningf("Unsupported protection level %v", pl) + glog.Errorf("Unsupported protection level for %v: %v", kek.GetKekUri(), pl) continue } @@ -388,16 +390,16 @@ func (c *StetClient) unwrapAndValidateShares(ctx context.Context, wrappedShares unwrapped.URI = kmd.uri default: - glog.Warningf("Unsupported KekInfo type: %v", x) + glog.Errorf("Unsupported KekInfo type for %v: %v", kek.GetKekUri(), x) continue } if !shares.ValidateShare(unwrapped.Share, wrapped.GetHash()) { - glog.Warningf("Unwrapped share %v does not have the expected hash", i) + glog.Errorf("Unwrapped share %v does not have the expected hash", i) continue } - glog.Infof("Successfully unwrapped share #%v", i+1) + glog.Infof("Successfully unwrapped share %v", unwrapped.URI) unwrappedShares = append(unwrappedShares, unwrapped) } @@ -483,6 +485,26 @@ func (c *StetClient) Encrypt(ctx context.Context, input io.Reader, output io.Wri } +// Returns whether the number of unwrapped shares is sufficient for combining the DEK based +// on the splitting +func enoughUnwrappedShares(shares []shares.UnwrappedShare, config *configpb.KeyConfig) error { + numShares := len(shares) + + // Return error if no unwrapped shares found. + if numShares == 0 { + return fmt.Errorf("no unwrapped shares") + } + + // Otherwise, verify the number of shares is enough for the specified shamir threshold. + if _, ok := config.GetKeySplittingAlgorithm().(*configpb.KeyConfig_Shamir); ok { + if int64(numShares) < config.GetShamir().GetThreshold() { + return fmt.Errorf("number of unwrapped shares %v is less than threshold needed %v", numShares, config.GetShamir().GetThreshold()) + } + } + + return nil +} + // Decrypt writes the decrypted data to the `output` writer, and returns the // key URIs used during decryption and the blob ID decrypted. func (c *StetClient) Decrypt(ctx context.Context, input io.Reader, output io.Writer, stetConfig *configpb.StetConfig) (*StetMetadata, error) { @@ -522,6 +544,13 @@ func (c *StetClient) Decrypt(ctx context.Context, input io.Reader, output io.Wri return nil, fmt.Errorf("error unwrapping and validating shares: %v", err) } + // Verify we have enough unwrapped shares for the key config. + if err := enoughUnwrappedShares(unwrappedShares, matchingKeyConfig); err != nil { + return nil, fmt.Errorf("not enough unwrapped shares to recombine DEK, see logs for unwrap details: %v", err) + } else if len(unwrappedShares) < len(matchingKeyConfig.GetKekInfos()) { + glog.Warningf("Recieved enough unwrapped shares to recombine DEK, but not all shares unwrapped successfully: %v of %v unwrapped, see logs for unwrap details.", len(unwrappedShares), len(matchingKeyConfig.GetKekInfos())) + } + combinedShares, err := shares.CombineUnwrappedShares(matchingKeyConfig, unwrappedShares) if err != nil { return nil, fmt.Errorf("error combining unwrapped shares: %v", err) diff --git a/client/client_test.go b/client/client_test.go index 8253b6e..8bdc803 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1772,3 +1772,54 @@ func TestNewConfspaceConfig(t *testing.T) { }) } } + +func TestEnoughUnwrappedShares(t *testing.T) { + testShare := shares.UnwrappedShare{[]byte("test share"), "test hash"} + testcases := []struct { + name string + shares []shares.UnwrappedShare + config *configpb.KeyConfig + expectErr bool + }{ + { + name: "With no split", + shares: []shares.UnwrappedShare{testShare}, + config: &configpb.KeyConfig{ + KeySplittingAlgorithm: &configpb.KeyConfig_NoSplit{true}, + }, + }, + { + name: "With shamir config", + shares: []shares.UnwrappedShare{testShare, testShare}, + config: &configpb.KeyConfig{ + KeySplittingAlgorithm: &configpb.KeyConfig_Shamir{&configpb.ShamirConfig{Threshold: 2, Shares: 2}}, + }, + }, + { + name: "Zero shares", + shares: []shares.UnwrappedShare{}, + config: &configpb.KeyConfig{ + KeySplittingAlgorithm: &configpb.KeyConfig_NoSplit{true}, + }, + expectErr: true, + }, + { + name: "Less shares than shamir threshold", + shares: []shares.UnwrappedShare{testShare}, + config: &configpb.KeyConfig{ + KeySplittingAlgorithm: &configpb.KeyConfig_Shamir{&configpb.ShamirConfig{Threshold: 2, Shares: 2}}, + }, + expectErr: true, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + err := enoughUnwrappedShares(tc.shares, tc.config) + + if (err != nil) != tc.expectErr { + t.Errorf("enoughWrappedShares did not return expected output: want (err == nil) == %v, got %v", tc.expectErr, err) + } + }) + } +}