Skip to content

Commit

Permalink
Check for enough unwrapped shares before combining
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 583109015
Change-Id: Ib84b11c072c95b983c3bcc0eb73c165e777a1968
  • Loading branch information
jessieqliu authored and copybara-github committed Nov 16, 2023
1 parent b777577 commit 19fc482
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 11 deletions.
51 changes: 40 additions & 11 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,21 +325,21 @@ func (c *StetClient) unwrapAndValidateShares(ctx context.Context, wrappedShares
// implementation handle the subset of shares.
var unwrappedShares []shares.UnwrappedShare
for i, wrapped := range wrappedShares {
glog.Infof("Attempting to unwrap share #%v", i+1)
unwrapped := shares.UnwrappedShare{}
kek := opts.kekInfos[i]
glog.Infof("Attempting to unwrap share #%v, URI %v", i+1, kek.GetKekUri())

switch x := kek.KekType.(type) {
case *configpb.KekInfo_RsaFingerprint:
key, err := PrivateKeyForRSAFingerprint(kek, opts.asymmetricKeys)
if err != nil {
glog.Warningf("Failed to find public key for RSA fingerprint: %v", err)
glog.Errorf("Failed to find private key for RSA fingerprint: %v", err)
continue
}

unwrapped.Share, err = rsa.DecryptOAEP(sha256.New(), rand.Reader, key, wrapped.GetShare(), nil)
if err != nil {
glog.Warningf("Error unwrapping key share: %v", err)
glog.Errorf("Error unwrapping key share for %v: %v", kek.GetKekUri(), err)
continue
}

Expand All @@ -352,12 +352,14 @@ func (c *StetClient) unwrapAndValidateShares(ctx context.Context, wrappedShares

kmsClient, err := kmsClients.Client(ctx, creds)
if err != nil {
return nil, fmt.Errorf("error initializing Cloud KMS Client with credentials \"%v\": %v", creds, err)
glog.Errorf("Error initializing Cloud KMS Client with credentials \"%v\" for %v: %v", creds, kek.GetKekUri(), err)
continue
}

kmd, err := getKekURIMetadata(ctx, kmsClient, kek)
if err != nil {
return nil, fmt.Errorf("Error retrieving KEK Metadata: %v", err)
glog.Errorf("Error retrieving KEK Metadata for %v: %v", kek.GetKekUri(), err)
continue
}

// Unwrap share via KMS.
Expand All @@ -369,17 +371,17 @@ func (c *StetClient) unwrapAndValidateShares(ctx context.Context, wrappedShares
}
unwrapped.Share, err = cloudkms.UnwrapShare(ctx, kmsClient, unwrapOpts)
if err != nil {
glog.Warningf("Error unwrapping key share: %v", err)
glog.Errorf("Error unwrapping key sharefor %v: %v", kmd.uri, err)
continue
}
case rpb.ProtectionLevel_EXTERNAL:
unwrapped.Share, err = c.ekmSecureSessionUnwrap(ctx, wrapped.GetShare(), *kmd)
if err != nil {
glog.Warningf("Error unwrapping with external EKM for %v: %v", kmd.uri, err)
glog.Errorf("Error unwrapping with external EKM for %v: %v", kmd.uri, err)
continue
}
default:
glog.Warningf("Unsupported protection level %v", pl)
glog.Errorf("Unsupported protection level for %v: %v", kek.GetKekUri(), pl)
continue
}

Expand All @@ -388,16 +390,16 @@ func (c *StetClient) unwrapAndValidateShares(ctx context.Context, wrappedShares
unwrapped.URI = kmd.uri

default:
glog.Warningf("Unsupported KekInfo type: %v", x)
glog.Errorf("Unsupported KekInfo type for %v: %v", kek.GetKekUri(), x)
continue
}

if !shares.ValidateShare(unwrapped.Share, wrapped.GetHash()) {
glog.Warningf("Unwrapped share %v does not have the expected hash", i)
glog.Errorf("Unwrapped share %v does not have the expected hash", i)
continue
}

glog.Infof("Successfully unwrapped share #%v", i+1)
glog.Infof("Successfully unwrapped share %v", unwrapped.URI)
unwrappedShares = append(unwrappedShares, unwrapped)
}

Expand Down Expand Up @@ -483,6 +485,26 @@ func (c *StetClient) Encrypt(ctx context.Context, input io.Reader, output io.Wri

}

// Returns whether the number of unwrapped shares is sufficient for combining the DEK based
// on the splitting
func enoughUnwrappedShares(shares []shares.UnwrappedShare, config *configpb.KeyConfig) error {
numShares := len(shares)

// Return error if no unwrapped shares found.
if numShares == 0 {
return fmt.Errorf("no unwrapped shares")
}

// Otherwise, verify the number of shares is enough for the specified shamir threshold.
if _, ok := config.GetKeySplittingAlgorithm().(*configpb.KeyConfig_Shamir); ok {
if int64(numShares) < config.GetShamir().GetThreshold() {
return fmt.Errorf("number of unwrapped shares %v is less than threshold needed %v", numShares, config.GetShamir().GetThreshold())
}
}

return nil
}

// Decrypt writes the decrypted data to the `output` writer, and returns the
// key URIs used during decryption and the blob ID decrypted.
func (c *StetClient) Decrypt(ctx context.Context, input io.Reader, output io.Writer, stetConfig *configpb.StetConfig) (*StetMetadata, error) {
Expand Down Expand Up @@ -522,6 +544,13 @@ func (c *StetClient) Decrypt(ctx context.Context, input io.Reader, output io.Wri
return nil, fmt.Errorf("error unwrapping and validating shares: %v", err)
}

// Verify we have enough unwrapped shares for the key config.
if err := enoughUnwrappedShares(unwrappedShares, matchingKeyConfig); err != nil {
return nil, fmt.Errorf("not enough unwrapped shares to recombine DEK, see logs for unwrap details: %v", err)
} else if len(unwrappedShares) < len(matchingKeyConfig.GetKekInfos()) {
glog.Warningf("Recieved enough unwrapped shares to recombine DEK, but not all shares unwrapped successfully: %v of %v unwrapped, see logs for unwrap details.", len(unwrappedShares), len(matchingKeyConfig.GetKekInfos()))
}

combinedShares, err := shares.CombineUnwrappedShares(matchingKeyConfig, unwrappedShares)
if err != nil {
return nil, fmt.Errorf("error combining unwrapped shares: %v", err)
Expand Down
51 changes: 51 additions & 0 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1772,3 +1772,54 @@ func TestNewConfspaceConfig(t *testing.T) {
})
}
}

func TestEnoughUnwrappedShares(t *testing.T) {
testShare := shares.UnwrappedShare{[]byte("test share"), "test hash"}
testcases := []struct {
name string
shares []shares.UnwrappedShare
config *configpb.KeyConfig
expectErr bool
}{
{
name: "With no split",
shares: []shares.UnwrappedShare{testShare},
config: &configpb.KeyConfig{
KeySplittingAlgorithm: &configpb.KeyConfig_NoSplit{true},
},
},
{
name: "With shamir config",
shares: []shares.UnwrappedShare{testShare, testShare},
config: &configpb.KeyConfig{
KeySplittingAlgorithm: &configpb.KeyConfig_Shamir{&configpb.ShamirConfig{Threshold: 2, Shares: 2}},
},
},
{
name: "Zero shares",
shares: []shares.UnwrappedShare{},
config: &configpb.KeyConfig{
KeySplittingAlgorithm: &configpb.KeyConfig_NoSplit{true},
},
expectErr: true,
},
{
name: "Less shares than shamir threshold",
shares: []shares.UnwrappedShare{testShare},
config: &configpb.KeyConfig{
KeySplittingAlgorithm: &configpb.KeyConfig_Shamir{&configpb.ShamirConfig{Threshold: 2, Shares: 2}},
},
expectErr: true,
},
}

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
err := enoughUnwrappedShares(tc.shares, tc.config)

if (err != nil) != tc.expectErr {
t.Errorf("enoughWrappedShares did not return expected output: want (err == nil) == %v, got %v", tc.expectErr, err)
}
})
}
}

0 comments on commit 19fc482

Please sign in to comment.