@@ -87,6 +87,8 @@ private static readonly Dictionary<string, SqlColumnEncryptionKeyStoreProvider>
8787 /// Instance-level list of custom key store providers. It can be set more than once by the user.
8888 private IReadOnlyDictionary < string , SqlColumnEncryptionKeyStoreProvider > _customColumnEncryptionKeyStoreProviders ;
8989
90+ private Func < SqlAuthenticationParameters , CancellationToken , Task < SqlAuthenticationToken > > _accessTokenCallback ;
91+
9092 internal bool HasColumnEncryptionKeyStoreProvidersRegistered =>
9193 _customColumnEncryptionKeyStoreProviders is not null && _customColumnEncryptionKeyStoreProviders . Count > 0 ;
9294
@@ -272,7 +274,7 @@ internal static List<string> GetColumnEncryptionSystemKeyStoreProvidersNames()
272274 }
273275
274276 /// <summary>
275- /// This function returns a list of the names of the custom providers currently registered. If the
277+ /// This function returns a list of the names of the custom providers currently registered. If the
276278 /// instance-level cache is not empty, that cache is used, else the global cache is used.
277279 /// </summary>
278280 /// <returns>Combined list of provider names</returns>
@@ -344,7 +346,7 @@ public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(IDictionary<st
344346 new ( customProviders , StringComparer . OrdinalIgnoreCase ) ;
345347
346348 // Set the dictionary to the ReadOnly dictionary.
347- // This method can be called more than once. Re-registering a new collection will replace the
349+ // This method can be called more than once. Re-registering a new collection will replace the
348350 // old collection of providers.
349351 _customColumnEncryptionKeyStoreProviders = customColumnEncryptionKeyStoreProviders ;
350352 }
@@ -584,7 +586,7 @@ public override string ConnectionString
584586 }
585587 set
586588 {
587- if ( _credential != null || _accessToken != null )
589+ if ( _credential != null || _accessToken != null || _accessTokenCallback != null )
588590 {
589591 SqlConnectionString connectionOptions = new SqlConnectionString ( value ) ;
590592 if ( _credential != null )
@@ -620,12 +622,18 @@ public override string ConnectionString
620622
621623 CheckAndThrowOnInvalidCombinationOfConnectionStringAndSqlCredential ( connectionOptions ) ;
622624 }
623- else if ( _accessToken != null )
625+
626+ if ( _accessToken != null )
624627 {
625628 CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessToken ( connectionOptions ) ;
626629 }
630+
631+ if ( _accessTokenCallback != null )
632+ {
633+ CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback ( connectionOptions ) ;
634+ }
627635 }
628- ConnectionString_Set ( new SqlConnectionPoolKey ( value , _credential , _accessToken ) ) ;
636+ ConnectionString_Set ( new SqlConnectionPoolKey ( value , _credential , _accessToken , _accessTokenCallback ) ) ;
629637 _connectionString = value ; // Change _connectionString value only after value is validated
630638 CacheConnectionStringProperties ( ) ;
631639 }
@@ -685,11 +693,34 @@ public string AccessToken
685693 }
686694
687695 // Need to call ConnectionString_Set to do proper pool group check
688- ConnectionString_Set ( new SqlConnectionPoolKey ( _connectionString , credential : _credential , accessToken : value ) ) ;
696+ ConnectionString_Set ( new SqlConnectionPoolKey ( _connectionString , credential : _credential , accessToken : value , accessTokenCallback : null ) ) ;
689697 _accessToken = value ;
690698 }
691699 }
692700
701+ /// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/AccessTokenCallback/*' />
702+ public Func < SqlAuthenticationParameters , CancellationToken , Task < SqlAuthenticationToken > > AccessTokenCallback
703+ {
704+ get { return _accessTokenCallback ; }
705+ set
706+ {
707+ // If a connection is connecting or is ever opened, AccessToken callback cannot be set
708+ if ( ! InnerConnection . AllowSetConnectionString )
709+ {
710+ throw ADP . OpenConnectionPropertySet ( nameof ( AccessTokenCallback ) , InnerConnection . State ) ;
711+ }
712+
713+ if ( value != null )
714+ {
715+ // Check if the usage of AccessToken has any conflict with the keys used in connection string and credential
716+ CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback ( ( SqlConnectionString ) ConnectionOptions ) ;
717+ }
718+
719+ ConnectionString_Set ( new SqlConnectionPoolKey ( _connectionString , credential : _credential , accessToken : null , accessTokenCallback : value ) ) ;
720+ _accessTokenCallback = value ;
721+ }
722+ }
723+
693724 /// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/Database/*' />
694725 [ ResDescription ( StringsHelper . ResourceNames . SqlConnection_Database ) ]
695726 [ ResCategory ( StringsHelper . ResourceNames . SqlConnection_DataSource ) ]
@@ -970,6 +1001,7 @@ public SqlCredential Credential
9701001 }
9711002
9721003 CheckAndThrowOnInvalidCombinationOfConnectionStringAndSqlCredential ( connectionOptions ) ;
1004+
9731005 if ( _accessToken != null )
9741006 {
9751007 throw ADP . InvalidMixedUsageOfCredentialAndAccessToken ( ) ;
@@ -979,7 +1011,7 @@ public SqlCredential Credential
9791011 _credential = value ;
9801012
9811013 // Need to call ConnectionString_Set to do proper pool group check
982- ConnectionString_Set ( new SqlConnectionPoolKey ( _connectionString , _credential , accessToken : _accessToken ) ) ;
1014+ ConnectionString_Set ( new SqlConnectionPoolKey ( _connectionString , _credential , accessToken : _accessToken , accessTokenCallback : _accessTokenCallback ) ) ;
9831015 }
9841016 }
9851017
@@ -1026,6 +1058,33 @@ private void CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessToken(S
10261058 {
10271059 throw ADP . InvalidMixedUsageOfCredentialAndAccessToken ( ) ;
10281060 }
1061+
1062+ if ( _accessTokenCallback != null )
1063+ {
1064+ throw ADP . InvalidMixedUsageOfAccessTokenAndTokenCallback ( ) ;
1065+ }
1066+ }
1067+
1068+ // CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback: check if the usage of AccessTokenCallback has any conflict
1069+ // with the keys used in connection string and credential
1070+ // If there is any conflict, it throws InvalidOperationException
1071+ // This is to be used setter of ConnectionString and AccessTokenCallback properties
1072+ private void CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback ( SqlConnectionString connectionOptions )
1073+ {
1074+ if ( UsesIntegratedSecurity ( connectionOptions ) )
1075+ {
1076+ throw ADP . InvalidMixedUsageOfAccessTokenCallbackAndIntegratedSecurity ( ) ;
1077+ }
1078+
1079+ if ( UsesAuthentication ( connectionOptions ) )
1080+ {
1081+ throw ADP . InvalidMixedUsageOfAccessTokenCallbackAndAuthentication ( ) ;
1082+ }
1083+
1084+ if ( _accessToken != null )
1085+ {
1086+ throw ADP . InvalidMixedUsageOfAccessTokenAndTokenCallback ( ) ;
1087+ }
10291088 }
10301089
10311090 /// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/DbProviderFactory/*' />
@@ -2128,7 +2187,7 @@ public static void ChangePassword(string connectionString, string newPassword)
21282187 throw ADP . InvalidArgumentLength ( nameof ( newPassword ) , TdsEnums . MAXLEN_NEWPASSWORD ) ;
21292188 }
21302189
2131- SqlConnectionPoolKey key = new SqlConnectionPoolKey ( connectionString , credential : null , accessToken : null ) ;
2190+ SqlConnectionPoolKey key = new SqlConnectionPoolKey ( connectionString , credential : null , accessToken : null , accessTokenCallback : null ) ;
21322191
21332192 SqlConnectionString connectionOptions = SqlConnectionFactory . FindSqlConnectionOptions ( key ) ;
21342193 if ( connectionOptions . IntegratedSecurity )
@@ -2177,7 +2236,7 @@ public static void ChangePassword(string connectionString, SqlCredential credent
21772236 throw ADP . InvalidArgumentLength ( nameof ( newSecurePassword ) , TdsEnums . MAXLEN_NEWPASSWORD ) ;
21782237 }
21792238
2180- SqlConnectionPoolKey key = new SqlConnectionPoolKey ( connectionString , credential , accessToken : null ) ;
2239+ SqlConnectionPoolKey key = new SqlConnectionPoolKey ( connectionString , credential , accessToken : null , accessTokenCallback : null ) ;
21812240
21822241 SqlConnectionString connectionOptions = SqlConnectionFactory . FindSqlConnectionOptions ( key ) ;
21832242
@@ -2216,7 +2275,7 @@ private static void ChangePassword(string connectionString, SqlConnectionString
22162275 if ( con != null )
22172276 con . Dispose ( ) ;
22182277 }
2219- SqlConnectionPoolKey key = new SqlConnectionPoolKey ( connectionString , credential , accessToken : null ) ;
2278+ SqlConnectionPoolKey key = new SqlConnectionPoolKey ( connectionString , credential , accessToken : null , accessTokenCallback : null ) ;
22202279
22212280 SqlConnectionFactory . SingletonInstance . ClearPool ( key ) ;
22222281 }
0 commit comments