diff --git a/doc/samples/SqlConnection_AccessTokenCallback.cs b/doc/samples/SqlConnection_AccessTokenCallback.cs new file mode 100644 index 0000000000..c761f640fd --- /dev/null +++ b/doc/samples/SqlConnection_AccessTokenCallback.cs @@ -0,0 +1,35 @@ +using System; +using System.Data; +// +using Microsoft.Data.SqlClient; +using Azure.Identity; + +class Program +{ + static void Main() + { + OpenSqlConnection(); + Console.ReadLine(); + } + + private static void OpenSqlConnection() + { + string connectionString = GetConnectionString(); + using (SqlConnection connection = new SqlConnection("Data Source=contoso.database.windows.net;Initial Catalog=AdventureWorks;") + { + AccessTokenCallback = async (authParams, cancellationToken) => + { + var cred = new DefaultAzureCredential(); + string scope = authParams.Resource.EndsWith(s_defaultScopeSuffix) ? authParams.Resource : authParams.Resource + s_defaultScopeSuffix; + var token = await cred.GetTokenAsync(new TokenRequestContext(new[] { scope }), cancellationToken); + return new SqlAuthenticationToken(token.Token, token.ExpiresOn); + } + }) + { + connection.Open(); + Console.WriteLine("ServerVersion: {0}", connection.ServerVersion); + Console.WriteLine("State: {0}", connection.State); + } + } +} +// diff --git a/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml b/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml index bfc53dddac..08a9cc88fd 100644 --- a/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml +++ b/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml @@ -129,6 +129,22 @@ using (SqlConnection connection = new SqlConnection(connectionString)) The access token for the connection. To be added. + + Gets or sets the access token callback for the connection. + + The Func that takes a and and returns a . + + . + + [!code-csharp[SqlConnection_AccessTokenCallback Example#1](~/../sqlclient/doc/samples/SqlConnection_AccessTokenCallback.cs#1)] + + ]]> + + The AccessTokenCallback is combined with other conflicting authentication configurations. + To be added. To be added. diff --git a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs index f10c66b386..92cb042b76 100644 --- a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs +++ b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs @@ -4,6 +4,7 @@ // NOTE: The current Microsoft.VSDesigner editor attributes are implemented for System.Data.SqlClient, and are not publicly available. // New attributes that are designed to work with Microsoft.Data.SqlClient and are publicly documented should be included in future. + [assembly: System.CLSCompliant(true)] namespace Microsoft.Data { @@ -839,6 +840,8 @@ public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(System.Collect /// [System.ComponentModel.DesignerSerializationVisibilityAttribute(0)] public System.Guid ClientConnectionId { get { throw null; } } + /// + public System.Func> AccessTokenCallback { get { throw null; } set { } } /// /// for internal test only diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs index dd23098d8f..cf13db3389 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs @@ -87,6 +87,8 @@ private static readonly Dictionary /// Instance-level list of custom key store providers. It can be set more than once by the user. private IReadOnlyDictionary _customColumnEncryptionKeyStoreProviders; + private Func> _accessTokenCallback; + internal bool HasColumnEncryptionKeyStoreProvidersRegistered => _customColumnEncryptionKeyStoreProviders is not null && _customColumnEncryptionKeyStoreProviders.Count > 0; @@ -272,7 +274,7 @@ internal static List GetColumnEncryptionSystemKeyStoreProvidersNames() } /// - /// This function returns a list of the names of the custom providers currently registered. If the + /// This function returns a list of the names of the custom providers currently registered. If the /// instance-level cache is not empty, that cache is used, else the global cache is used. /// /// Combined list of provider names @@ -344,7 +346,7 @@ public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(IDictionary + public Func> AccessTokenCallback + { + get { return _accessTokenCallback; } + set + { + // If a connection is connecting or is ever opened, AccessToken callback cannot be set + if (!InnerConnection.AllowSetConnectionString) + { + throw ADP.OpenConnectionPropertySet(nameof(AccessTokenCallback), InnerConnection.State); + } + + if (value != null) + { + // Check if the usage of AccessToken has any conflict with the keys used in connection string and credential + CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback((SqlConnectionString)ConnectionOptions); + } + + ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: null, accessTokenCallback: value)); + _accessTokenCallback = value; + } + } + /// [ResDescription(StringsHelper.ResourceNames.SqlConnection_Database)] [ResCategory(StringsHelper.ResourceNames.SqlConnection_DataSource)] @@ -970,6 +1001,7 @@ public SqlCredential Credential } CheckAndThrowOnInvalidCombinationOfConnectionStringAndSqlCredential(connectionOptions); + if (_accessToken != null) { throw ADP.InvalidMixedUsageOfCredentialAndAccessToken(); @@ -979,7 +1011,7 @@ public SqlCredential Credential _credential = value; // Need to call ConnectionString_Set to do proper pool group check - ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: _accessToken)); + ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: _accessToken, accessTokenCallback: _accessTokenCallback)); } } @@ -1026,6 +1058,33 @@ private void CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessToken(S { throw ADP.InvalidMixedUsageOfCredentialAndAccessToken(); } + + if(_accessTokenCallback != null) + { + throw ADP.InvalidMixedUsageOfAccessTokenAndTokenCallback(); + } + } + + // CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback: check if the usage of AccessTokenCallback has any conflict + // with the keys used in connection string and credential + // If there is any conflict, it throws InvalidOperationException + // This is to be used setter of ConnectionString and AccessTokenCallback properties + private void CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback(SqlConnectionString connectionOptions) + { + if (UsesIntegratedSecurity(connectionOptions)) + { + throw ADP.InvalidMixedUsageOfAccessTokenCallbackAndIntegratedSecurity(); + } + + if (UsesAuthentication(connectionOptions)) + { + throw ADP.InvalidMixedUsageOfAccessTokenCallbackAndAuthentication(); + } + + if(_accessToken != null) + { + throw ADP.InvalidMixedUsageOfAccessTokenAndTokenCallback(); + } } /// @@ -2128,7 +2187,7 @@ public static void ChangePassword(string connectionString, string newPassword) throw ADP.InvalidArgumentLength(nameof(newPassword), TdsEnums.MAXLEN_NEWPASSWORD); } - SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null); + SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, accessTokenCallback: null); SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key); if (connectionOptions.IntegratedSecurity) @@ -2177,7 +2236,7 @@ public static void ChangePassword(string connectionString, SqlCredential credent throw ADP.InvalidArgumentLength(nameof(newSecurePassword), TdsEnums.MAXLEN_NEWPASSWORD); } - SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null); + SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null); SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key); @@ -2216,7 +2275,7 @@ private static void ChangePassword(string connectionString, SqlConnectionString if (con != null) con.Dispose(); } - SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null); + SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null); SqlConnectionFactory.SingletonInstance.ClearPool(key); } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs index 93450aec7b..7526f38623 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs @@ -133,7 +133,7 @@ override protected DbConnectionInternal CreateConnection(DbConnectionOptions opt opt = new SqlConnectionString(opt, instanceName, userInstance: false, setEnlistValue: null); poolGroupProviderInfo = null; // null so we do not pass to constructor below... } - return new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool); + return new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool, key.AccessTokenCallback); } protected override DbConnectionOptions CreateConnectionOptions(string connectionString, DbConnectionOptions previous) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index 66631151c9..cc727103df 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -130,6 +130,7 @@ internal sealed class SqlInternalConnectionTds : SqlInternalConnection, IDisposa // The Federated Authentication returned by TryGetFedAuthTokenLocked or GetFedAuthToken. SqlFedAuthToken _fedAuthToken = null; internal byte[] _accessTokenInBytes; + internal readonly Func> _accessTokenCallback; private readonly ActiveDirectoryAuthenticationTimeoutRetryHelper _activeDirectoryAuthTimeoutRetryHelper; private readonly SqlAuthenticationProviderManager _sqlAuthenticationProviderManager; @@ -434,19 +435,20 @@ internal SqlConnectionTimeoutErrorInternal TimeoutErrorInternal // the new Login7 packet will always write out the new password (or a length of zero and no bytes if not present) // internal SqlInternalConnectionTds( - DbConnectionPoolIdentity identity, - SqlConnectionString connectionOptions, - SqlCredential credential, - object providerInfo, - string newPassword, - SecureString newSecurePassword, - bool redirectedUserInstance, - SqlConnectionString userConnectionOptions = null, // NOTE: userConnectionOptions may be different to connectionOptions if the connection string has been expanded (see SqlConnectionString.Expand) - SessionData reconnectSessionData = null, - bool applyTransientFaultHandling = false, - string accessToken = null, - DbConnectionPool pool = null - ) : base(connectionOptions) + DbConnectionPoolIdentity identity, + SqlConnectionString connectionOptions, + SqlCredential credential, + object providerInfo, + string newPassword, + SecureString newSecurePassword, + bool redirectedUserInstance, + SqlConnectionString userConnectionOptions = null, // NOTE: userConnectionOptions may be different to connectionOptions if the connection string has been expanded (see SqlConnectionString.Expand) + SessionData reconnectSessionData = null, + bool applyTransientFaultHandling = false, + string accessToken = null, + DbConnectionPool pool = null, + Func> accessTokenCallback = null) : base(connectionOptions) { #if DEBUG @@ -479,6 +481,8 @@ internal SqlInternalConnectionTds( _accessTokenInBytes = System.Text.Encoding.Unicode.GetBytes(accessToken); } + _accessTokenCallback = accessTokenCallback; + _activeDirectoryAuthTimeoutRetryHelper = new ActiveDirectoryAuthenticationTimeoutRetryHelper(); _sqlAuthenticationProviderManager = SqlAuthenticationProviderManager.Instance; @@ -1327,7 +1331,8 @@ private void Login(ServerInfo server, TimeoutTimer timeout, string newPassword, || ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryMSI || ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryDefault // Since AD Integrated may be acting like Windows integrated, additionally check _fedAuthRequired - || (ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryIntegrated && _fedAuthRequired)) + || (ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryIntegrated && _fedAuthRequired) + || _accessTokenCallback != null) { requestedFeatures |= TdsEnums.FeatureExtension.FedAuth; _federatedAuthenticationInfoRequested = true; @@ -2152,6 +2157,7 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo) { Debug.Assert((ConnectionOptions._hasUserIdKeyword && ConnectionOptions._hasPasswordKeyword) || _credential != null + || _accessTokenCallback != null || ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryInteractive || ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow || ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity @@ -2354,7 +2360,7 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo) string username = null; var authProvider = _sqlAuthenticationProviderManager.GetProvider(ConnectionOptions.Authentication); - if (authProvider == null) + if (authProvider == null && _accessTokenCallback == null) throw SQL.CannotFindAuthProvider(ConnectionOptions.Authentication.ToString()); // retry getting access token once if MsalException.error_code is unknown_error. @@ -2365,11 +2371,11 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo) try { var authParamsBuilder = new SqlAuthenticationParameters.Builder( - authenticationMethod: ConnectionOptions.Authentication, - resource: fedAuthInfo.spn, - authority: fedAuthInfo.stsurl, - serverName: ConnectionOptions.DataSource, - databaseName: ConnectionOptions.InitialCatalog) + authenticationMethod: ConnectionOptions.Authentication, + resource: fedAuthInfo.spn, + authority: fedAuthInfo.stsurl, + serverName: ConnectionOptions.DataSource, + databaseName: ConnectionOptions.InitialCatalog) .WithConnectionId(_clientConnectionId) .WithConnectionTimeout(ConnectionOptions.ConnectTimeout); switch (ConnectionOptions.Authentication) @@ -2439,7 +2445,38 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo) } break; default: - throw SQL.UnsupportedAuthenticationSpecified(ConnectionOptions.Authentication); + if (_accessTokenCallback == null) + { + throw SQL.UnsupportedAuthenticationSpecified(ConnectionOptions.Authentication); + } + + if (_activeDirectoryAuthTimeoutRetryHelper.State == ActiveDirectoryAuthenticationTimeoutRetryState.Retrying) + { + _fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken; + } + else + { + if (_credential != null) + { + username = _credential.UserId; + authParamsBuilder.WithUserId(username).WithPassword(_credential.Password); + } + else + { + authParamsBuilder.WithUserId(ConnectionOptions.UserID); + authParamsBuilder.WithPassword(ConnectionOptions.Password); + } + SqlAuthenticationParameters parameters = authParamsBuilder; + CancellationTokenSource cts = new(); + // Use Connection timeout value to cancel token acquire request after certain period of time.(int) + if (_timeout.MillisecondsRemaining < Int32.MaxValue) + { + cts.CancelAfter((int)_timeout.MillisecondsRemaining); + } + _fedAuthToken = Task.Run(async () => await _accessTokenCallback(parameters, cts.Token)).GetAwaiter().GetResult().ToSqlFedAuthToken(); + _activeDirectoryAuthTimeoutRetryHelper.CachedToken = _fedAuthToken; + } + break; } Debug.Assert(_fedAuthToken.accessToken != null, "AccessToken should not be null."); @@ -2488,17 +2525,21 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo) // Deal with normal MsalExceptions. catch (MsalException msalException) { - if (MsalError.UnknownError != msalException.ErrorCode - || _timeout.IsExpired - || _timeout.MillisecondsRemaining <= sleepInterval) + if (MsalError.UnknownError != msalException.ErrorCode || _timeout.IsExpired || _timeout.MillisecondsRemaining <= sleepInterval) { SqlClientEventSource.Log.TryTraceEvent(" {0}", msalException.ErrorCode); throw ADP.CreateSqlException(msalException, ConnectionOptions, this, username); } - SqlClientEventSource.Log.TryAdvancedTraceEvent(" {0}, sleeping {1}[Milliseconds]", ObjectID, sleepInterval); - SqlClientEventSource.Log.TryAdvancedTraceEvent(" {0}, remaining {1}[Milliseconds]", ObjectID, _timeout.MillisecondsRemaining); + SqlClientEventSource.Log.TryAdvancedTraceEvent( + " {0}, sleeping {1}[Milliseconds]", + ObjectID, + sleepInterval); + SqlClientEventSource.Log.TryAdvancedTraceEvent( + " {0}, remaining {1}[Milliseconds]", + ObjectID, + _timeout.MillisecondsRemaining); Thread.Sleep(sleepInterval); sleepInterval *= 2; @@ -2506,7 +2547,21 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo) // All other exceptions from MSAL/Azure Identity APIs catch (Exception e) { - throw SqlException.CreateException(new() { new(0, (byte)0x00, (byte)TdsEnums.FATAL_ERROR_CLASS, ConnectionOptions.DataSource, e.Message, ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0) }, "", this, e); + throw SqlException.CreateException( + new() + { + new( + 0, + (byte)0x00, + (byte)TdsEnums.FATAL_ERROR_CLASS, + ConnectionOptions.DataSource, + e.Message, + ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, + 0) + }, + "", + this, + e); } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index e3de792213..85a4159ca7 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -1052,7 +1052,8 @@ private PreLoginHandshakeStatus ConsumePreLoginHandshake( { // Validate Certificate if Trust Server Certificate=false and Encryption forced (EncryptionOptions.ON) from Server. bool shouldValidateServerCert = (_encryptionOption == EncryptionOptions.ON && !trustServerCert) || - (_connHandler._accessTokenInBytes != null && !trustServerCert); + ((_connHandler._accessTokenInBytes != null || _connHandler._accessTokenCallback != null) + && !trustServerCert); uint info = (shouldValidateServerCert ? TdsEnums.SNI_SSL_VALIDATE_CERTIFICATE : 0) | (is2005OrLater ? TdsEnums.SNI_SSL_USE_SCHANNEL_CACHE : 0); @@ -1114,7 +1115,7 @@ private PreLoginHandshakeStatus ConsumePreLoginHandshake( // Or AccessToken is not null, mean token based authentication is used. if ((_connHandler.ConnectionOptions != null && _connHandler.ConnectionOptions.Authentication != SqlAuthenticationMethod.NotSpecified) - || _connHandler._accessTokenInBytes != null) + || _connHandler._accessTokenInBytes != null || _connHandler._accessTokenCallback != null) { fedAuthRequired = payload[payloadOffset] == 0x01 ? true : false; } @@ -7925,7 +7926,14 @@ internal int WriteFedAuthFeatureRequest(FederatedAuthenticationFeatureExtensionD workflow = TdsEnums.MSALWORKFLOW_ACTIVEDIRECTORYDEFAULT; break; default: - Debug.Assert(false, "Unrecognized Authentication type for fedauth MSAL request"); + if (_connHandler._accessTokenCallback != null) + { + workflow = TdsEnums.MSALWORKFLOW_ACTIVEDIRECTORYTOKENCREDENTIAL; + } + else + { + Debug.Assert(false, "Unrecognized Authentication type for fedauth MSAL request"); + } break; } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.Designer.cs b/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.Designer.cs index f6fbcf8306..cf900c4553 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.Designer.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.Designer.cs @@ -429,6 +429,15 @@ internal static string ADP_InvalidMixedUsageOfAccessTokenAndIntegratedSecurity { } } + /// + /// Looks up a localized string similar to Cannot set the AccessToken property if the AccessTokenCallback has been set.. + /// + internal static string ADP_InvalidMixedUsageOfAccessTokenAndTokenCallback { + get { + return ResourceManager.GetString("ADP_InvalidMixedUsageOfAccessTokenAndTokenCallback", resourceCulture); + } + } + /// /// Looks up a localized string similar to Cannot set the AccessToken property if 'UserID', 'UID', 'Password', or 'PWD' has been specified in connection string.. /// @@ -438,6 +447,24 @@ internal static string ADP_InvalidMixedUsageOfAccessTokenAndUserIDPassword { } } + /// + /// Looks up a localized string similar to Cannot set the AccessTokenCallback property if the 'Integrated Security' connection string keyword has been set to 'true' or 'SSPI'.. + /// + internal static string ADP_InvalidMixedUsageOfAccessTokenCallbackAndIntegratedSecurity { + get { + return ResourceManager.GetString("ADP_InvalidMixedUsageOfAccessTokenCallbackAndIntegratedSecurity", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Cannot set the AccessTokenCallback property if 'Authentication=Active Directory Default' has been specified in the connection string.. + /// + internal static string ADP_InvalidMixedUsageOfAuthenticationAndTokenCallback { + get { + return ResourceManager.GetString("ADP_InvalidMixedUsageOfAuthenticationAndTokenCallback", resourceCulture); + } + } + /// /// Looks up a localized string similar to Cannot set the Credential property if the AccessToken property is already set.. /// @@ -950,7 +977,7 @@ internal static string Data_InvalidOffsetLength { return ResourceManager.GetString("Data_InvalidOffsetLength", resourceCulture); } } - + /// /// Looks up a localized string similar to Internal error occurred when retrying the download of the HGS root certificate after the initial request failed. Contact Customer Support Services.. /// @@ -1915,7 +1942,7 @@ internal static string SNI_ERROR_9 { } /// - /// Looks up a localized string similar to Incorrect physicalConnection type. + /// Looks up a localized string similar to Incorrect physicalConnection type.. /// internal static string SNI_IncorrectPhysicalConnectionType { get { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.resx b/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.resx index 45dd4459d9..7f52b2556f 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.resx +++ b/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.resx @@ -1941,4 +1941,13 @@ Socket did not throw expected '{0}' with error code '{1}'. + + Cannot set the AccessToken property if the AccessTokenCallback has been set. + + + Cannot set the AccessTokenCallback property if the 'Integrated Security' connection string keyword has been set to 'true' or 'SSPI'. + + + Cannot set the AccessTokenCallback property if 'Authentication=Active Directory Default' has been specified in the connection string. + diff --git a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs index ca229487be..4663cc151f 100644 --- a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs +++ b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs @@ -4,6 +4,7 @@ // NOTE: The current Microsoft.VSDesigner editor attributes are implemented for System.Data.SqlClient, and are not publicly available. // New attributes that are designed to work with Microsoft.Data.SqlClient and are publicly documented should be included in future. + [assembly: System.CLSCompliant(true)] [assembly: System.Resources.NeutralResourcesLanguageAttribute("en-US")] namespace Microsoft.Data @@ -767,6 +768,8 @@ public SqlConnection(string connectionString, Microsoft.Data.SqlClient.SqlCreden [System.ComponentModel.BrowsableAttribute(false)] [System.ComponentModel.DesignerSerializationVisibilityAttribute(0)] public string AccessToken { get { throw null; } set { } } + /// + public System.Func> AccessTokenCallback { get { throw null; } set { } } /// [System.ComponentModel.DesignerSerializationVisibilityAttribute(0)] public System.Guid ClientConnectionId { get { throw null; } } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs index f9c02e8efc..23893afcf6 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs @@ -71,6 +71,8 @@ private static Dictionary s_systemC /// Instance-level list of custom key store providers. It can be set more than once by the user. private IReadOnlyDictionary _customColumnEncryptionKeyStoreProviders; + private Func> _accessTokenCallback; + internal bool HasColumnEncryptionKeyStoreProvidersRegistered => _customColumnEncryptionKeyStoreProviders is not null && _customColumnEncryptionKeyStoreProviders.Count > 0; @@ -739,7 +741,30 @@ public string AccessToken _accessToken = value; // Need to call ConnectionString_Set to do proper pool group check - ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, _accessToken, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo)); + ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, _accessToken, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo, null)); + } + } + + /// + public Func> AccessTokenCallback + { + get { return _accessTokenCallback; } + set + { + // If a connection is connecting or is ever opened, AccessToken callback cannot be set + if (!InnerConnection.AllowSetConnectionString) + { + throw ADP.OpenConnectionPropertySet(nameof(AccessTokenCallback), InnerConnection.State); + } + + if (value != null) + { + // Check if the usage of AccessToken has any conflict with the keys used in connection string and credential + CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback((SqlConnectionString)ConnectionOptions); + } + + ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, null, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo, value)); + _accessTokenCallback = value; } } @@ -776,7 +801,7 @@ override public string ConnectionString } set { - if (_credential != null || _accessToken != null) + if (_credential != null || _accessToken != null || _accessTokenCallback != null) { SqlConnectionString connectionOptions = new SqlConnectionString(value); if (_credential != null) @@ -812,12 +837,18 @@ override public string ConnectionString CheckAndThrowOnInvalidCombinationOfConnectionStringAndSqlCredential(connectionOptions); } - else if (_accessToken != null) + + if (_accessToken != null) { CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessToken(connectionOptions); } + + if (_accessTokenCallback != null) + { + CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback(connectionOptions); + } } - ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo)); + ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo, _accessTokenCallback)); _connectionString = value; // Change _connectionString value only after value is validated CacheConnectionStringProperties(); } @@ -1154,17 +1185,17 @@ public SqlCredential Credential } CheckAndThrowOnInvalidCombinationOfConnectionStringAndSqlCredential(connectionOptions); + if (_accessToken != null) { throw ADP.InvalidMixedUsageOfCredentialAndAccessToken(); } - } _credential = value; // Need to call ConnectionString_Set to do proper pool group check - ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, _accessToken, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo)); + ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, _accessToken, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo, _accessTokenCallback)); } } @@ -1221,6 +1252,33 @@ private void CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessToken(S { throw ADP.InvalidMixedUsageOfAccessTokenAndCredential(); } + + if (_accessTokenCallback != null) + { + throw ADP.InvalidMixedUsageOfAccessTokenAndTokenCallback(); + } + } + + // CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback: check if the usage of AccessTokenCallback has any conflict + // with the keys used in connection string and credential + // If there is any conflict, it throws InvalidOperationException + // This is to be used setter of ConnectionString and AccessTokenCallback properties + private void CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback(SqlConnectionString connectionOptions) + { + if (UsesIntegratedSecurity(connectionOptions)) + { + throw ADP.InvalidMixedUsageOfAccessTokenCallbackAndIntegratedSecurity(); + } + + if (UsesAuthentication(connectionOptions)) + { + throw ADP.InvalidMixedUsageOfAccessTokenCallbackAndAuthentication(); + } + + if (_accessToken != null) + { + throw ADP.InvalidMixedUsageOfAccessTokenAndTokenCallback(); + } } // @@ -2699,7 +2757,7 @@ public static void ChangePassword(string connectionString, string newPassword) throw ADP.InvalidArgumentLength("newPassword", TdsEnums.MAXLEN_NEWPASSWORD); } - SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, serverCertificateValidationCallback: null, clientCertificateRetrievalCallback: null, originalNetworkAddressInfo: null); + SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, serverCertificateValidationCallback: null, clientCertificateRetrievalCallback: null, originalNetworkAddressInfo: null, accessTokenCallback: null); SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key); if (connectionOptions.IntegratedSecurity || connectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryIntegrated) @@ -2755,7 +2813,7 @@ public static void ChangePassword(string connectionString, SqlCredential credent throw ADP.InvalidArgumentLength("newSecurePassword", TdsEnums.MAXLEN_NEWPASSWORD); } - SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, serverCertificateValidationCallback: null, clientCertificateRetrievalCallback: null, originalNetworkAddressInfo: null); + SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, serverCertificateValidationCallback: null, clientCertificateRetrievalCallback: null, originalNetworkAddressInfo: null, accessTokenCallback: null); SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key); @@ -2800,7 +2858,7 @@ private static void ChangePassword(string connectionString, SqlConnectionString throw SQL.ChangePasswordRequires2005(); } } - SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, serverCertificateValidationCallback: null, clientCertificateRetrievalCallback: null, originalNetworkAddressInfo: null); + SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, serverCertificateValidationCallback: null, clientCertificateRetrievalCallback: null, originalNetworkAddressInfo: null, accessTokenCallback: null); SqlConnectionFactory.SingletonInstance.ClearPool(key); } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs index e2051784a3..aca7bfc151 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs @@ -142,7 +142,7 @@ override protected DbConnectionInternal CreateConnection(DbConnectionOptions opt opt = new SqlConnectionString(opt, instanceName, false /* user instance=false */, null /* do not modify the Enlist value */); poolGroupProviderInfo = null; // null so we do not pass to constructor below... } - result = new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, key.ServerCertificateValidationCallback, key.ClientCertificateRetrievalCallback, pool, key.AccessToken, key.OriginalNetworkAddressInfo, applyTransientFaultHandling: applyTransientFaultHandling); + result = new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, key.ServerCertificateValidationCallback, key.ClientCertificateRetrievalCallback, pool, key.AccessToken, key.OriginalNetworkAddressInfo, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessTokenCallback); } return result; } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index da623f874e..f6b81a533c 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -136,6 +136,7 @@ sealed internal class SqlInternalConnectionTds : SqlInternalConnection, IDisposa // The Federated Authentication returned by TryGetFedAuthTokenLocked or GetFedAuthToken. SqlFedAuthToken _fedAuthToken = null; internal byte[] _accessTokenInBytes; + internal readonly Func> _accessTokenCallback; private readonly ActiveDirectoryAuthenticationTimeoutRetryHelper _activeDirectoryAuthTimeoutRetryHelper; private readonly SqlAuthenticationProviderManager _sqlAuthenticationProviderManager; @@ -432,7 +433,9 @@ internal SqlInternalConnectionTds( DbConnectionPool pool = null, string accessToken = null, SqlClientOriginalNetworkAddressInfo originalNetworkAddressInfo = null, - bool applyTransientFaultHandling = false) : base(connectionOptions) + bool applyTransientFaultHandling = false, + Func> accessTokenCallback = null) : base(connectionOptions) { #if DEBUG @@ -487,6 +490,8 @@ internal SqlInternalConnectionTds( _accessTokenInBytes = System.Text.Encoding.Unicode.GetBytes(accessToken); } + _accessTokenCallback = accessTokenCallback; + _activeDirectoryAuthTimeoutRetryHelper = new ActiveDirectoryAuthenticationTimeoutRetryHelper(); _sqlAuthenticationProviderManager = SqlAuthenticationProviderManager.Instance; @@ -1591,7 +1596,8 @@ private void Login(ServerInfo server, TimeoutTimer timeout, string newPassword, || ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryMSI || ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryDefault // Since AD Integrated may be acting like Windows integrated, additionally check _fedAuthRequired - || (ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryIntegrated && _fedAuthRequired)) + || (ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryIntegrated && _fedAuthRequired) + || _accessTokenCallback != null) { requestedFeatures |= TdsEnums.FeatureExtension.FedAuth; _federatedAuthenticationInfoRequested = true; @@ -2578,6 +2584,7 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo) { Debug.Assert((ConnectionOptions._hasUserIdKeyword && ConnectionOptions._hasPasswordKeyword) || _credential != null + || _accessTokenCallback != null || ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryInteractive || ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity || ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryMSI @@ -2772,13 +2779,13 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo) int numberOfAttempts = 0; // Object that will be returned to the caller, containing all required data about the token. - SqlFedAuthToken fedAuthToken = new SqlFedAuthToken(); + _fedAuthToken = new SqlFedAuthToken(); // Username to use in error messages. string username = null; var authProvider = _sqlAuthenticationProviderManager.GetProvider(ConnectionOptions.Authentication); - if (authProvider == null) + if (authProvider == null && _accessTokenCallback == null) throw SQL.CannotFindAuthProvider(ConnectionOptions.Authentication.ToString()); // retry getting access token once if MsalException.error_code is unknown_error. @@ -2802,14 +2809,14 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo) username = TdsEnums.NTAUTHORITYANONYMOUSLOGON; if (_activeDirectoryAuthTimeoutRetryHelper.State == ActiveDirectoryAuthenticationTimeoutRetryState.Retrying) { - fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken; + _fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken; } else { // We use Task.Run here in all places to execute task synchronously in the same context. // Fixes block-over-async deadlock possibilities https://github.com/dotnet/SqlClient/issues/1209 - fedAuthToken = Task.Run(async () => await authProvider.AcquireTokenAsync(authParamsBuilder)).GetAwaiter().GetResult().ToSqlFedAuthToken(); - _activeDirectoryAuthTimeoutRetryHelper.CachedToken = fedAuthToken; + _fedAuthToken = Task.Run(async () => await authProvider.AcquireTokenAsync(authParamsBuilder)).GetAwaiter().GetResult().ToSqlFedAuthToken(); + _activeDirectoryAuthTimeoutRetryHelper.CachedToken = _fedAuthToken; } break; case SqlAuthenticationMethod.ActiveDirectoryInteractive: @@ -2819,20 +2826,20 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo) case SqlAuthenticationMethod.ActiveDirectoryDefault: if (_activeDirectoryAuthTimeoutRetryHelper.State == ActiveDirectoryAuthenticationTimeoutRetryState.Retrying) { - fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken; + _fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken; } else { authParamsBuilder.WithUserId(ConnectionOptions.UserID); - fedAuthToken = Task.Run(async () => await authProvider.AcquireTokenAsync(authParamsBuilder)).GetAwaiter().GetResult().ToSqlFedAuthToken(); - _activeDirectoryAuthTimeoutRetryHelper.CachedToken = fedAuthToken; + _fedAuthToken = Task.Run(async () => await authProvider.AcquireTokenAsync(authParamsBuilder)).GetAwaiter().GetResult().ToSqlFedAuthToken(); + _activeDirectoryAuthTimeoutRetryHelper.CachedToken = _fedAuthToken; } break; case SqlAuthenticationMethod.ActiveDirectoryPassword: case SqlAuthenticationMethod.ActiveDirectoryServicePrincipal: if (_activeDirectoryAuthTimeoutRetryHelper.State == ActiveDirectoryAuthenticationTimeoutRetryState.Retrying) { - fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken; + _fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken; } else { @@ -2840,22 +2847,53 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo) { username = _credential.UserId; authParamsBuilder.WithUserId(username).WithPassword(_credential.Password); - fedAuthToken = Task.Run(async () => await authProvider.AcquireTokenAsync(authParamsBuilder)).GetAwaiter().GetResult().ToSqlFedAuthToken(); + _fedAuthToken = Task.Run(async () => await authProvider.AcquireTokenAsync(authParamsBuilder)).GetAwaiter().GetResult().ToSqlFedAuthToken(); } else { username = ConnectionOptions.UserID; authParamsBuilder.WithUserId(username).WithPassword(ConnectionOptions.Password); - fedAuthToken = Task.Run(async () => await authProvider.AcquireTokenAsync(authParamsBuilder)).GetAwaiter().GetResult().ToSqlFedAuthToken(); + _fedAuthToken = Task.Run(async () => await authProvider.AcquireTokenAsync(authParamsBuilder)).GetAwaiter().GetResult().ToSqlFedAuthToken(); } - _activeDirectoryAuthTimeoutRetryHelper.CachedToken = fedAuthToken; + _activeDirectoryAuthTimeoutRetryHelper.CachedToken = _fedAuthToken; } break; default: - throw SQL.UnsupportedAuthenticationSpecified(ConnectionOptions.Authentication); + if (_accessTokenCallback == null) + { + throw SQL.UnsupportedAuthenticationSpecified(ConnectionOptions.Authentication); + } + + if (_activeDirectoryAuthTimeoutRetryHelper.State == ActiveDirectoryAuthenticationTimeoutRetryState.Retrying) + { + _fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken; + } + else + { + if (_credential != null) + { + username = _credential.UserId; + authParamsBuilder.WithUserId(username).WithPassword(_credential.Password); + } + else + { + authParamsBuilder.WithUserId(ConnectionOptions.UserID); + authParamsBuilder.WithPassword(ConnectionOptions.Password); + } + SqlAuthenticationParameters parameters = authParamsBuilder; + CancellationTokenSource cts = new(); + // Use Connection timeout value to cancel token acquire request after certain period of time.(int) + if (_timeout.MillisecondsRemaining < Int32.MaxValue) + { + cts.CancelAfter((int)_timeout.MillisecondsRemaining); + } + _fedAuthToken = Task.Run(async () => await _accessTokenCallback(parameters, cts.Token)).GetAwaiter().GetResult().ToSqlFedAuthToken(); + _activeDirectoryAuthTimeoutRetryHelper.CachedToken = _fedAuthToken; + } + break; } - Debug.Assert(fedAuthToken.accessToken != null, "AccessToken should not be null."); + Debug.Assert(_fedAuthToken.accessToken != null, "AccessToken should not be null."); #if DEBUG if (_forceMsalRetry) { @@ -2923,17 +2961,17 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo) } } - Debug.Assert(fedAuthToken != null, "fedAuthToken should not be null."); - Debug.Assert(fedAuthToken.accessToken != null && fedAuthToken.accessToken.Length > 0, "fedAuthToken.accessToken should not be null or empty."); + Debug.Assert(_fedAuthToken != null, "fedAuthToken should not be null."); + Debug.Assert(_fedAuthToken.accessToken != null && _fedAuthToken.accessToken.Length > 0, "fedAuthToken.accessToken should not be null or empty."); // Store the newly generated token in _newDbConnectionPoolAuthenticationContext, only if using pooling. if (_dbConnectionPool != null) { - DateTime expirationTime = DateTime.FromFileTimeUtc(fedAuthToken.expirationFileTime); - _newDbConnectionPoolAuthenticationContext = new DbConnectionPoolAuthenticationContext(fedAuthToken.accessToken, expirationTime); + DateTime expirationTime = DateTime.FromFileTimeUtc(_fedAuthToken.expirationFileTime); + _newDbConnectionPoolAuthenticationContext = new DbConnectionPoolAuthenticationContext(_fedAuthToken.accessToken, expirationTime); } SqlClientEventSource.Log.TryTraceEvent(" {0}, Finished generating federated authentication token.", ObjectID); - return fedAuthToken; + return _fedAuthToken; } internal void OnFeatureExtAck(int featureId, byte[] data) diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index c407e1c6e9..16b2431a8a 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -1487,7 +1487,10 @@ private PreLoginHandshakeStatus ConsumePreLoginHandshake( } // Validate Certificate if Trust Server Certificate=false and Encryption forced (EncryptionOptions.ON) from Server. - bool shouldValidateServerCert = (_encryptionOption == EncryptionOptions.ON && !trustServerCert) || ((authType != SqlAuthenticationMethod.NotSpecified || _connHandler._accessTokenInBytes != null) && !trustServerCert); + bool shouldValidateServerCert = (_encryptionOption == EncryptionOptions.ON && !trustServerCert) || + ((authType != SqlAuthenticationMethod.NotSpecified || (_connHandler._accessTokenInBytes != null || + _connHandler._accessTokenCallback != null)) + && !trustServerCert); UInt32 info = (shouldValidateServerCert ? TdsEnums.SNI_SSL_VALIDATE_CERTIFICATE : 0) | (is2005OrLater && (_encryptionOption & EncryptionOptions.CLIENT_CERT) == 0 ? TdsEnums.SNI_SSL_USE_SCHANNEL_CACHE : 0); @@ -1551,7 +1554,7 @@ private PreLoginHandshakeStatus ConsumePreLoginHandshake( // Or AccessToken is not null, mean token based authentication is used. if ((_connHandler.ConnectionOptions != null && _connHandler.ConnectionOptions.Authentication != SqlAuthenticationMethod.NotSpecified) - || _connHandler._accessTokenInBytes != null) + || _connHandler._accessTokenInBytes != null || _connHandler._accessTokenCallback != null) { fedAuthRequired = payload[payloadOffset] == 0x01 ? true : false; } @@ -8744,7 +8747,14 @@ internal int WriteFedAuthFeatureRequest(FederatedAuthenticationFeatureExtensionD workflow = TdsEnums.MSALWORKFLOW_ACTIVEDIRECTORYDEFAULT; break; default: - Debug.Assert(false, "Unrecognized Authentication type for fedauth MSAL request"); + if (_connHandler._accessTokenCallback != null) + { + workflow = TdsEnums.MSALWORKFLOW_ACTIVEDIRECTORYTOKENCREDENTIAL; + } + else + { + Debug.Assert(false, "Unrecognized Authentication type for fedauth MSAL request"); + } break; } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.Designer.cs b/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.Designer.cs index b30f54f8e5..5fd4a54382 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.Designer.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.Designer.cs @@ -906,6 +906,15 @@ internal static string ADP_InvalidMixedUsageOfAccessTokenAndIntegratedSecurity { } } + /// + /// Looks up a localized string similar to Cannot set the AccessToken property if the AccessTokenCallback has been set.. + /// + internal static string ADP_InvalidMixedUsageOfAccessTokenAndTokenCallback { + get { + return ResourceManager.GetString("ADP_InvalidMixedUsageOfAccessTokenAndTokenCallback", resourceCulture); + } + } + /// /// Looks up a localized string similar to Cannot set the AccessToken property if 'UserID', 'UID', 'Password', or 'PWD' has been specified in connection string.. /// @@ -915,6 +924,24 @@ internal static string ADP_InvalidMixedUsageOfAccessTokenAndUserIDPassword { } } + /// + /// Looks up a localized string similar to Cannot set the AccessTokenCallback property if the 'Integrated Security' connection string keyword has been set to 'true' or 'SSPI'.. + /// + internal static string ADP_InvalidMixedUsageOfAccessTokenCallbackAndIntegratedSecurity { + get { + return ResourceManager.GetString("ADP_InvalidMixedUsageOfAccessTokenCallbackAndIntegratedSecurity", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Cannot set the AccessTokenCallback property if 'Authentication=Active Directory Default' has been specified in the connection string.. + /// + internal static string ADP_InvalidMixedUsageOfAuthenticationAndTokenCallback { + get { + return ResourceManager.GetString("ADP_InvalidMixedUsageOfAuthenticationAndTokenCallback", resourceCulture); + } + } + /// /// Looks up a localized string similar to Cannot set the Credential property if the AccessToken property is already set.. /// @@ -5559,6 +5586,15 @@ internal static string DbConnectionString_ApplicationName { } } + /// + /// Looks up a localized string similar to When true, enables usage of the Asynchronous functionality in the .NET Framework Data Provider.. + /// + internal static string DbConnectionString_AsynchronousProcessing { + get { + return ResourceManager.GetString("DbConnectionString_AsynchronousProcessing", resourceCulture); + } + } + /// /// Looks up a localized string similar to The name of the primary file, including the full path name, of an attachable database.. /// @@ -8910,6 +8946,15 @@ internal static string SQL_ArgumentLengthMismatch { } } + /// + /// Looks up a localized string similar to This command requires an asynchronous connection. Set "Asynchronous Processing=true" in the connection string.. + /// + internal static string SQL_AsyncConnectionRequired { + get { + return ResourceManager.GetString("SQL_AsyncConnectionRequired", resourceCulture); + } + } + /// /// Looks up a localized string similar to The asynchronous operation has already completed.. /// @@ -10593,6 +10638,15 @@ internal static string SqlConnection_AccessToken { } } + /// + /// Looks up a localized string similar to State of connection, synchronous or asynchronous. 'Asynchronous Processing=x' in the connection string.. + /// + internal static string SqlConnection_Asynchronous { + get { + return ResourceManager.GetString("SqlConnection_Asynchronous", resourceCulture); + } + } + /// /// Looks up a localized string similar to A guid to represent the physical connection.. /// diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.resx b/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.resx index 85627eae31..2b2374e490 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.resx +++ b/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.resx @@ -4635,4 +4635,13 @@ The '{0}' platform is not supported when targeting .NET Framework. - + + Cannot set the AccessToken property if the AccessTokenCallback has been set. + + + Cannot set the AccessTokenCallback property if the 'Integrated Security' connection string keyword has been set to 'true' or 'SSPI'. + + + Cannot set the AccessTokenCallback property if 'Authentication=Active Directory Default' has been specified in the connection string. + + \ No newline at end of file diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.Windows.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.Windows.cs index 9b9c0f3341..abeadf86b7 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.Windows.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.Windows.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Runtime.InteropServices; using System.Runtime.Versioning; using System.Security; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs index 8f48131f8d..08e46ed284 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs @@ -840,7 +840,8 @@ private static string ConnectionStateMsg(ConnectionState state) { // MDAC 82165, if the ConnectionState enum to msg the localization looks weird return state switch { - (ConnectionState.Closed) or (ConnectionState.Connecting | ConnectionState.Broken) => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg_Closed), + (ConnectionState.Closed) => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg_Closed), + (ConnectionState.Connecting | ConnectionState.Broken) => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg_Closed), (ConnectionState.Connecting) => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg_Connecting), (ConnectionState.Open) => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg_Open), (ConnectionState.Open | ConnectionState.Executing) => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg_OpenExecuting), @@ -1267,7 +1268,16 @@ static internal InvalidOperationException InvalidMixedUsageOfAccessTokenAndAuthe static internal Exception InvalidMixedUsageOfCredentialAndAccessToken() => InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfCredentialAndAccessToken)); -#endregion + + static internal Exception InvalidMixedUsageOfAccessTokenAndTokenCallback() + => InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfAccessTokenAndTokenCallback)); + + internal static Exception InvalidMixedUsageOfAccessTokenCallbackAndAuthentication() + => InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfAuthenticationAndTokenCallback)); + + internal static Exception InvalidMixedUsageOfAccessTokenCallbackAndIntegratedSecurity() + => InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfAccessTokenCallbackAndIntegratedSecurity)); + #endregion internal static bool IsEmpty(string str) => string.IsNullOrEmpty(str); internal static readonly IntPtr s_ptrZero = IntPtr.Zero; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionPoolKey.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionPoolKey.cs index c35ce6f08e..1eed6a229d 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionPoolKey.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionPoolKey.cs @@ -2,7 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; using Microsoft.Data.Common; namespace Microsoft.Data.SqlClient @@ -14,9 +17,11 @@ internal class SqlConnectionPoolKey : DbConnectionPoolKey private int _hashValue; private readonly SqlCredential _credential; private readonly string _accessToken; + private Func> _accessTokenCallback; internal SqlCredential Credential => _credential; internal string AccessToken => _accessToken; + internal Func> AccessTokenCallback => _accessTokenCallback; internal override string ConnectionString { @@ -48,11 +53,13 @@ internal SqlConnectionPoolKey(string connectionString, string accessToken, ServerCertificateValidationCallback serverCertificateValidationCallback, ClientCertificateRetrievalCallback clientCertificateRetrievalCallback, - SqlClientOriginalNetworkAddressInfo originalNetworkAddressInfo) : base(connectionString) + SqlClientOriginalNetworkAddressInfo originalNetworkAddressInfo, + Func> accessTokenCallback = null) : base(connectionString) { - Debug.Assert(_credential == null || _accessToken == null, "Credential and AccessToken can't have the value at the same time."); + Debug.Assert(_credential == null || _accessToken == null || accessTokenCallback == null, "Credential, AccessToken, and Callback can't have a value at the same time."); _credential = credential; _accessToken = accessToken; + _accessTokenCallback = accessTokenCallback; _serverCertificateValidationCallback = serverCertificateValidationCallback; _clientCertificateRetrievalCallback = clientCertificateRetrievalCallback; _originalNetworkAddressInfo = originalNetworkAddressInfo; @@ -61,11 +68,12 @@ internal SqlConnectionPoolKey(string connectionString, #endregion #else #region NET Core - internal SqlConnectionPoolKey(string connectionString, SqlCredential credential, string accessToken) : base(connectionString) + internal SqlConnectionPoolKey(string connectionString, SqlCredential credential, string accessToken, Func> accessTokenCallback) : base(connectionString) { - Debug.Assert(_credential == null || _accessToken == null, "Credential and AccessToken can't have the value at the same time."); + Debug.Assert(credential == null || accessToken == null || accessTokenCallback == null, "Credential, AccessToken, and Callback can't have a value at the same time."); _credential = credential; _accessToken = accessToken; + _accessTokenCallback = accessTokenCallback; CalculateHashCode(); } #endregion @@ -75,6 +83,7 @@ private SqlConnectionPoolKey(SqlConnectionPoolKey key) : base(key) { _credential = key.Credential; _accessToken = key.AccessToken; + _accessTokenCallback = key._accessTokenCallback; #if NETFRAMEWORK _serverCertificateValidationCallback = key._serverCertificateValidationCallback; _clientCertificateRetrievalCallback = key._clientCertificateRetrievalCallback; @@ -92,6 +101,7 @@ public override bool Equals(object obj) return (obj is SqlConnectionPoolKey key && _credential == key._credential && ConnectionString == key.ConnectionString + && _accessTokenCallback == key._accessTokenCallback && string.CompareOrdinal(_accessToken, key._accessToken) == 0 #if NETFRAMEWORK && _serverCertificateValidationCallback == key._serverCertificateValidationCallback @@ -124,6 +134,13 @@ private void CalculateHashCode() _hashValue = _hashValue * 17 + _accessToken.GetHashCode(); } } + else if (_accessTokenCallback != null) + { + unchecked + { + _hashValue = _hashValue * 17 + _accessTokenCallback.GetHashCode(); + } + } #if NETFRAMEWORK if (_originalNetworkAddressInfo != null) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs index 81c2ed1570..8a8cb3772d 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs @@ -272,6 +272,7 @@ public enum FedAuthLibrary : byte public const byte MSALWORKFLOW_ACTIVEDIRECTORYDEVICECODEFLOW = 0x03; // Using the Interactive byte as that is the closest we have public const byte MSALWORKFLOW_ACTIVEDIRECTORYMANAGEDIDENTITY = 0x03; // Using the Interactive byte as that's supported for Identity based authentication public const byte MSALWORKFLOW_ACTIVEDIRECTORYDEFAULT = 0x03; // Using the Interactive byte as that is the closest we have to non-password based authentication modes + public const byte MSALWORKFLOW_ACTIVEDIRECTORYTOKENCREDENTIAL = 0x03; // Using the Interactive byte as that is the closest we have to non-password based authentication modes public enum ActiveDirectoryWorkflow : byte { diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs index e61bdfe4ac..03e6e7b5c7 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs @@ -7,6 +7,7 @@ using System.Data.Common; using System.Reflection; using System.Security; +using System.Threading; using System.Threading.Tasks; using Microsoft.SqlServer.TDS.Servers; using Xunit; @@ -264,5 +265,86 @@ public void ConnectionTestValidCredentialCombination() Assert.Equal(sqlCredential, conn.Credential); } + + [Fact] + public void ConnectionTestAccessTokenCallbackCombinations() + { + var cleartextCredsConnStr = "User=test;Password=test;"; + var sspiConnStr = "Integrated Security=true;"; + var authConnStr = "Authentication=ActiveDirectoryPassword"; + var testPassword = new SecureString(); + testPassword.MakeReadOnly(); + var sqlCredential = new SqlCredential(string.Empty, testPassword); + Func> callback = (ctx, token) => + Task.FromResult(new SqlAuthenticationToken("invalid", DateTimeOffset.MaxValue)); + + // Successes + using (var conn = new SqlConnection(cleartextCredsConnStr)) + { + conn.AccessTokenCallback = callback; + conn.AccessTokenCallback = null; + } + + using (var conn = new SqlConnection(string.Empty, sqlCredential)) + { + conn.AccessTokenCallback = null; + conn.AccessTokenCallback = callback; + } + + using (var conn = new SqlConnection() + { + AccessTokenCallback = callback + }) + { + conn.Credential = sqlCredential; + } + + using (var conn = new SqlConnection() + { + AccessTokenCallback = callback + }) + { + conn.ConnectionString = cleartextCredsConnStr; + } + + //Failures + using (var conn = new SqlConnection(sspiConnStr)) + { + Assert.Throws(() => + { + conn.AccessTokenCallback = callback; + }); + } + + using (var conn = new SqlConnection(authConnStr)) + { + Assert.Throws(() => + { + conn.AccessTokenCallback = callback; + }); + } + + using (var conn = new SqlConnection() + { + AccessTokenCallback = callback + }) + { + Assert.Throws(() => + { + conn.ConnectionString = sspiConnStr; + }); + } + + using (var conn = new SqlConnection() + { + AccessTokenCallback = callback + }) + { + Assert.Throws(() => + { + conn.ConnectionString = authConnStr; + }); + } + } } } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs index 13f607583d..598fd1c8d8 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs @@ -203,7 +203,7 @@ public static IEnumerable GetConnectionStrings(bool withEnclave) private static string GenerateAccessToken(string authorityURL, string aADAuthUserID, string aADAuthPassword) { - return AcquireTokenAsync(authorityURL, aADAuthUserID, aADAuthPassword).Result; + return AcquireTokenAsync(authorityURL, aADAuthUserID, aADAuthPassword).GetAwaiter().GetResult(); } private static Task AcquireTokenAsync(string authorityURL, string userID, string password) => Task.Run(() => diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/AADConnectionTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/AADConnectionTest.cs index 70cae84d73..52e5705140 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/AADConnectionTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/AADConnectionTest.cs @@ -7,6 +7,8 @@ using System.Security; using System.Threading; using System.Threading.Tasks; +using Azure.Core; +using Azure.Identity; using Microsoft.Identity.Client; using Xunit; @@ -551,6 +553,77 @@ public static void ActiveDirectoryDefaultWithPasswordMustFail() Assert.Contains(expectedMessage, e.Message); } + [ConditionalFact(nameof(IsAADConnStringsSetup))] + public static void ActiveDirectoryDefaultWithAccessTokenCallbackMustFail() + { + // connection fails with expected error message. + string[] credKeys = { "Authentication", "User ID", "Password", "UID", "PWD" }; + string connStrWithNoCred = DataTestUtility.RemoveKeysInConnStr(DataTestUtility.AADPasswordConnectionString, credKeys) + + "Authentication=ActiveDirectoryDefault"; + InvalidOperationException e = Assert.Throws(() => + { + using (SqlConnection conn = new SqlConnection(connStrWithNoCred)) + { + conn.AccessTokenCallback = (ctx, token) => + Task.FromResult(new SqlAuthenticationToken("my token", DateTimeOffset.MaxValue)); + conn.Open(); + + Assert.NotEqual(System.Data.ConnectionState.Open, conn.State); + } + }); + + string expectedMessage = "Cannot set the AccessTokenCallback property if 'Authentication=Active Directory Default' has been specified in the connection string."; + Assert.Contains(expectedMessage, e.Message); + } + + [ConditionalFact(nameof(IsAADConnStringsSetup))] + public static void AccessTokenCallbackMustOpenPassAndChangePropertyFail() + { + string[] credKeys = { "Authentication", "User ID", "Password", "UID", "PWD" }; + string connStr = DataTestUtility.RemoveKeysInConnStr(DataTestUtility.AADPasswordConnectionString, credKeys); + var cred = new DefaultAzureCredential(); + const string defaultScopeSuffix = "/.default"; + using (SqlConnection conn = new SqlConnection(connStr)) + { + conn.AccessTokenCallback = (ctx, cancellationToken) => + { + string scope = ctx.Resource.EndsWith(defaultScopeSuffix) ? ctx.Resource : ctx.Resource + defaultScopeSuffix; + AccessToken token = cred.GetToken(new TokenRequestContext(new[] { scope }), cancellationToken); + return Task.FromResult(new SqlAuthenticationToken(token.Token, token.ExpiresOn)); + }; + conn.Open(); + Assert.Equal(System.Data.ConnectionState.Open, conn.State); + + InvalidOperationException ex = Assert.Throws(() => conn.AccessTokenCallback = null); + string expectedMessage = "Not allowed to change the 'AccessTokenCallback' property. The connection's current state is open."; + Assert.Contains(expectedMessage, ex.Message); + } + } + + [ConditionalFact(nameof(IsAADConnStringsSetup))] + public static void AccessTokenCallbackReceivesUsernameAndPassword() + { + var userId = "someuser"; + var pwd = "somepassword"; + string[] credKeys = { "Authentication", "User ID", "Password", "UID", "PWD" }; + string connStr = DataTestUtility.RemoveKeysInConnStr(DataTestUtility.AADPasswordConnectionString, credKeys) + + $"User ID={userId}; Password={pwd}"; + var cred = new DefaultAzureCredential(); + const string defaultScopeSuffix = "/.default"; + using (SqlConnection conn = new SqlConnection(connStr)) + { + conn.AccessTokenCallback = (parms, cancellationToken) => + { + Assert.Equal(userId, parms.UserId); + Assert.Equal(pwd, parms.Password); + string scope = parms.Resource.EndsWith(defaultScopeSuffix) ? parms.Resource : parms.Resource + defaultScopeSuffix; + AccessToken token = cred.GetToken(new TokenRequestContext(new[] { scope }), cancellationToken); + return Task.FromResult(new SqlAuthenticationToken(token.Token, token.ExpiresOn)); + }; + conn.Open(); + } + } + [ConditionalFact(nameof(IsAADConnStringsSetup))] public static void ActiveDirectoryDefaultMustPass() {