diff --git a/doc/samples/SqlConnection_AccessTokenCallback.cs b/doc/samples/SqlConnection_AccessTokenCallback.cs
new file mode 100644
index 0000000000..c761f640fd
--- /dev/null
+++ b/doc/samples/SqlConnection_AccessTokenCallback.cs
@@ -0,0 +1,35 @@
+using System;
+using System.Data;
+//
+using Microsoft.Data.SqlClient;
+using Azure.Identity;
+
+class Program
+{
+ static void Main()
+ {
+ OpenSqlConnection();
+ Console.ReadLine();
+ }
+
+ private static void OpenSqlConnection()
+ {
+ string connectionString = GetConnectionString();
+ using (SqlConnection connection = new SqlConnection("Data Source=contoso.database.windows.net;Initial Catalog=AdventureWorks;")
+ {
+ AccessTokenCallback = async (authParams, cancellationToken) =>
+ {
+ var cred = new DefaultAzureCredential();
+ string scope = authParams.Resource.EndsWith(s_defaultScopeSuffix) ? authParams.Resource : authParams.Resource + s_defaultScopeSuffix;
+ var token = await cred.GetTokenAsync(new TokenRequestContext(new[] { scope }), cancellationToken);
+ return new SqlAuthenticationToken(token.Token, token.ExpiresOn);
+ }
+ })
+ {
+ connection.Open();
+ Console.WriteLine("ServerVersion: {0}", connection.ServerVersion);
+ Console.WriteLine("State: {0}", connection.State);
+ }
+ }
+}
+//
diff --git a/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml b/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml
index bfc53dddac..08a9cc88fd 100644
--- a/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml
+++ b/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml
@@ -129,6 +129,22 @@ using (SqlConnection connection = new SqlConnection(connectionString))
The access token for the connection.
To be added.
+
+ Gets or sets the access token callback for the connection.
+
+ The Func that takes a and and returns a .
+
+ .
+
+ [!code-csharp[SqlConnection_AccessTokenCallback Example#1](~/../sqlclient/doc/samples/SqlConnection_AccessTokenCallback.cs#1)]
+
+ ]]>
+
+ The AccessTokenCallback is combined with other conflicting authentication configurations.
+
To be added.
To be added.
diff --git a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs
index f10c66b386..92cb042b76 100644
--- a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs
@@ -4,6 +4,7 @@
// NOTE: The current Microsoft.VSDesigner editor attributes are implemented for System.Data.SqlClient, and are not publicly available.
// New attributes that are designed to work with Microsoft.Data.SqlClient and are publicly documented should be included in future.
+
[assembly: System.CLSCompliant(true)]
namespace Microsoft.Data
{
@@ -839,6 +840,8 @@ public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(System.Collect
///
[System.ComponentModel.DesignerSerializationVisibilityAttribute(0)]
public System.Guid ClientConnectionId { get { throw null; } }
+ ///
+ public System.Func> AccessTokenCallback { get { throw null; } set { } }
///
/// for internal test only
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs
index dd23098d8f..cf13db3389 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs
@@ -87,6 +87,8 @@ private static readonly Dictionary
/// Instance-level list of custom key store providers. It can be set more than once by the user.
private IReadOnlyDictionary _customColumnEncryptionKeyStoreProviders;
+ private Func> _accessTokenCallback;
+
internal bool HasColumnEncryptionKeyStoreProvidersRegistered =>
_customColumnEncryptionKeyStoreProviders is not null && _customColumnEncryptionKeyStoreProviders.Count > 0;
@@ -272,7 +274,7 @@ internal static List GetColumnEncryptionSystemKeyStoreProvidersNames()
}
///
- /// This function returns a list of the names of the custom providers currently registered. If the
+ /// This function returns a list of the names of the custom providers currently registered. If the
/// instance-level cache is not empty, that cache is used, else the global cache is used.
///
/// Combined list of provider names
@@ -344,7 +346,7 @@ public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(IDictionary
+ public Func> AccessTokenCallback
+ {
+ get { return _accessTokenCallback; }
+ set
+ {
+ // If a connection is connecting or is ever opened, AccessToken callback cannot be set
+ if (!InnerConnection.AllowSetConnectionString)
+ {
+ throw ADP.OpenConnectionPropertySet(nameof(AccessTokenCallback), InnerConnection.State);
+ }
+
+ if (value != null)
+ {
+ // Check if the usage of AccessToken has any conflict with the keys used in connection string and credential
+ CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback((SqlConnectionString)ConnectionOptions);
+ }
+
+ ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: null, accessTokenCallback: value));
+ _accessTokenCallback = value;
+ }
+ }
+
///
[ResDescription(StringsHelper.ResourceNames.SqlConnection_Database)]
[ResCategory(StringsHelper.ResourceNames.SqlConnection_DataSource)]
@@ -970,6 +1001,7 @@ public SqlCredential Credential
}
CheckAndThrowOnInvalidCombinationOfConnectionStringAndSqlCredential(connectionOptions);
+
if (_accessToken != null)
{
throw ADP.InvalidMixedUsageOfCredentialAndAccessToken();
@@ -979,7 +1011,7 @@ public SqlCredential Credential
_credential = value;
// Need to call ConnectionString_Set to do proper pool group check
- ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: _accessToken));
+ ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: _accessToken, accessTokenCallback: _accessTokenCallback));
}
}
@@ -1026,6 +1058,33 @@ private void CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessToken(S
{
throw ADP.InvalidMixedUsageOfCredentialAndAccessToken();
}
+
+ if(_accessTokenCallback != null)
+ {
+ throw ADP.InvalidMixedUsageOfAccessTokenAndTokenCallback();
+ }
+ }
+
+ // CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback: check if the usage of AccessTokenCallback has any conflict
+ // with the keys used in connection string and credential
+ // If there is any conflict, it throws InvalidOperationException
+ // This is to be used setter of ConnectionString and AccessTokenCallback properties
+ private void CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback(SqlConnectionString connectionOptions)
+ {
+ if (UsesIntegratedSecurity(connectionOptions))
+ {
+ throw ADP.InvalidMixedUsageOfAccessTokenCallbackAndIntegratedSecurity();
+ }
+
+ if (UsesAuthentication(connectionOptions))
+ {
+ throw ADP.InvalidMixedUsageOfAccessTokenCallbackAndAuthentication();
+ }
+
+ if(_accessToken != null)
+ {
+ throw ADP.InvalidMixedUsageOfAccessTokenAndTokenCallback();
+ }
}
///
@@ -2128,7 +2187,7 @@ public static void ChangePassword(string connectionString, string newPassword)
throw ADP.InvalidArgumentLength(nameof(newPassword), TdsEnums.MAXLEN_NEWPASSWORD);
}
- SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null);
+ SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, accessTokenCallback: null);
SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key);
if (connectionOptions.IntegratedSecurity)
@@ -2177,7 +2236,7 @@ public static void ChangePassword(string connectionString, SqlCredential credent
throw ADP.InvalidArgumentLength(nameof(newSecurePassword), TdsEnums.MAXLEN_NEWPASSWORD);
}
- SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null);
+ SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null);
SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key);
@@ -2216,7 +2275,7 @@ private static void ChangePassword(string connectionString, SqlConnectionString
if (con != null)
con.Dispose();
}
- SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null);
+ SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null);
SqlConnectionFactory.SingletonInstance.ClearPool(key);
}
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs
index 93450aec7b..7526f38623 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs
@@ -133,7 +133,7 @@ override protected DbConnectionInternal CreateConnection(DbConnectionOptions opt
opt = new SqlConnectionString(opt, instanceName, userInstance: false, setEnlistValue: null);
poolGroupProviderInfo = null; // null so we do not pass to constructor below...
}
- return new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool);
+ return new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool, key.AccessTokenCallback);
}
protected override DbConnectionOptions CreateConnectionOptions(string connectionString, DbConnectionOptions previous)
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs
index 66631151c9..cc727103df 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs
@@ -130,6 +130,7 @@ internal sealed class SqlInternalConnectionTds : SqlInternalConnection, IDisposa
// The Federated Authentication returned by TryGetFedAuthTokenLocked or GetFedAuthToken.
SqlFedAuthToken _fedAuthToken = null;
internal byte[] _accessTokenInBytes;
+ internal readonly Func> _accessTokenCallback;
private readonly ActiveDirectoryAuthenticationTimeoutRetryHelper _activeDirectoryAuthTimeoutRetryHelper;
private readonly SqlAuthenticationProviderManager _sqlAuthenticationProviderManager;
@@ -434,19 +435,20 @@ internal SqlConnectionTimeoutErrorInternal TimeoutErrorInternal
// the new Login7 packet will always write out the new password (or a length of zero and no bytes if not present)
//
internal SqlInternalConnectionTds(
- DbConnectionPoolIdentity identity,
- SqlConnectionString connectionOptions,
- SqlCredential credential,
- object providerInfo,
- string newPassword,
- SecureString newSecurePassword,
- bool redirectedUserInstance,
- SqlConnectionString userConnectionOptions = null, // NOTE: userConnectionOptions may be different to connectionOptions if the connection string has been expanded (see SqlConnectionString.Expand)
- SessionData reconnectSessionData = null,
- bool applyTransientFaultHandling = false,
- string accessToken = null,
- DbConnectionPool pool = null
- ) : base(connectionOptions)
+ DbConnectionPoolIdentity identity,
+ SqlConnectionString connectionOptions,
+ SqlCredential credential,
+ object providerInfo,
+ string newPassword,
+ SecureString newSecurePassword,
+ bool redirectedUserInstance,
+ SqlConnectionString userConnectionOptions = null, // NOTE: userConnectionOptions may be different to connectionOptions if the connection string has been expanded (see SqlConnectionString.Expand)
+ SessionData reconnectSessionData = null,
+ bool applyTransientFaultHandling = false,
+ string accessToken = null,
+ DbConnectionPool pool = null,
+ Func> accessTokenCallback = null) : base(connectionOptions)
{
#if DEBUG
@@ -479,6 +481,8 @@ internal SqlInternalConnectionTds(
_accessTokenInBytes = System.Text.Encoding.Unicode.GetBytes(accessToken);
}
+ _accessTokenCallback = accessTokenCallback;
+
_activeDirectoryAuthTimeoutRetryHelper = new ActiveDirectoryAuthenticationTimeoutRetryHelper();
_sqlAuthenticationProviderManager = SqlAuthenticationProviderManager.Instance;
@@ -1327,7 +1331,8 @@ private void Login(ServerInfo server, TimeoutTimer timeout, string newPassword,
|| ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryMSI
|| ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryDefault
// Since AD Integrated may be acting like Windows integrated, additionally check _fedAuthRequired
- || (ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryIntegrated && _fedAuthRequired))
+ || (ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryIntegrated && _fedAuthRequired)
+ || _accessTokenCallback != null)
{
requestedFeatures |= TdsEnums.FeatureExtension.FedAuth;
_federatedAuthenticationInfoRequested = true;
@@ -2152,6 +2157,7 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
{
Debug.Assert((ConnectionOptions._hasUserIdKeyword && ConnectionOptions._hasPasswordKeyword)
|| _credential != null
+ || _accessTokenCallback != null
|| ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryInteractive
|| ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow
|| ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity
@@ -2354,7 +2360,7 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
string username = null;
var authProvider = _sqlAuthenticationProviderManager.GetProvider(ConnectionOptions.Authentication);
- if (authProvider == null)
+ if (authProvider == null && _accessTokenCallback == null)
throw SQL.CannotFindAuthProvider(ConnectionOptions.Authentication.ToString());
// retry getting access token once if MsalException.error_code is unknown_error.
@@ -2365,11 +2371,11 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
try
{
var authParamsBuilder = new SqlAuthenticationParameters.Builder(
- authenticationMethod: ConnectionOptions.Authentication,
- resource: fedAuthInfo.spn,
- authority: fedAuthInfo.stsurl,
- serverName: ConnectionOptions.DataSource,
- databaseName: ConnectionOptions.InitialCatalog)
+ authenticationMethod: ConnectionOptions.Authentication,
+ resource: fedAuthInfo.spn,
+ authority: fedAuthInfo.stsurl,
+ serverName: ConnectionOptions.DataSource,
+ databaseName: ConnectionOptions.InitialCatalog)
.WithConnectionId(_clientConnectionId)
.WithConnectionTimeout(ConnectionOptions.ConnectTimeout);
switch (ConnectionOptions.Authentication)
@@ -2439,7 +2445,38 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
}
break;
default:
- throw SQL.UnsupportedAuthenticationSpecified(ConnectionOptions.Authentication);
+ if (_accessTokenCallback == null)
+ {
+ throw SQL.UnsupportedAuthenticationSpecified(ConnectionOptions.Authentication);
+ }
+
+ if (_activeDirectoryAuthTimeoutRetryHelper.State == ActiveDirectoryAuthenticationTimeoutRetryState.Retrying)
+ {
+ _fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken;
+ }
+ else
+ {
+ if (_credential != null)
+ {
+ username = _credential.UserId;
+ authParamsBuilder.WithUserId(username).WithPassword(_credential.Password);
+ }
+ else
+ {
+ authParamsBuilder.WithUserId(ConnectionOptions.UserID);
+ authParamsBuilder.WithPassword(ConnectionOptions.Password);
+ }
+ SqlAuthenticationParameters parameters = authParamsBuilder;
+ CancellationTokenSource cts = new();
+ // Use Connection timeout value to cancel token acquire request after certain period of time.(int)
+ if (_timeout.MillisecondsRemaining < Int32.MaxValue)
+ {
+ cts.CancelAfter((int)_timeout.MillisecondsRemaining);
+ }
+ _fedAuthToken = Task.Run(async () => await _accessTokenCallback(parameters, cts.Token)).GetAwaiter().GetResult().ToSqlFedAuthToken();
+ _activeDirectoryAuthTimeoutRetryHelper.CachedToken = _fedAuthToken;
+ }
+ break;
}
Debug.Assert(_fedAuthToken.accessToken != null, "AccessToken should not be null.");
@@ -2488,17 +2525,21 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
// Deal with normal MsalExceptions.
catch (MsalException msalException)
{
- if (MsalError.UnknownError != msalException.ErrorCode
- || _timeout.IsExpired
- || _timeout.MillisecondsRemaining <= sleepInterval)
+ if (MsalError.UnknownError != msalException.ErrorCode || _timeout.IsExpired || _timeout.MillisecondsRemaining <= sleepInterval)
{
SqlClientEventSource.Log.TryTraceEvent(" {0}", msalException.ErrorCode);
throw ADP.CreateSqlException(msalException, ConnectionOptions, this, username);
}
- SqlClientEventSource.Log.TryAdvancedTraceEvent(" {0}, sleeping {1}[Milliseconds]", ObjectID, sleepInterval);
- SqlClientEventSource.Log.TryAdvancedTraceEvent(" {0}, remaining {1}[Milliseconds]", ObjectID, _timeout.MillisecondsRemaining);
+ SqlClientEventSource.Log.TryAdvancedTraceEvent(
+ " {0}, sleeping {1}[Milliseconds]",
+ ObjectID,
+ sleepInterval);
+ SqlClientEventSource.Log.TryAdvancedTraceEvent(
+ " {0}, remaining {1}[Milliseconds]",
+ ObjectID,
+ _timeout.MillisecondsRemaining);
Thread.Sleep(sleepInterval);
sleepInterval *= 2;
@@ -2506,7 +2547,21 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
// All other exceptions from MSAL/Azure Identity APIs
catch (Exception e)
{
- throw SqlException.CreateException(new() { new(0, (byte)0x00, (byte)TdsEnums.FATAL_ERROR_CLASS, ConnectionOptions.DataSource, e.Message, ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0) }, "", this, e);
+ throw SqlException.CreateException(
+ new()
+ {
+ new(
+ 0,
+ (byte)0x00,
+ (byte)TdsEnums.FATAL_ERROR_CLASS,
+ ConnectionOptions.DataSource,
+ e.Message,
+ ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName,
+ 0)
+ },
+ "",
+ this,
+ e);
}
}
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs
index e3de792213..85a4159ca7 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs
@@ -1052,7 +1052,8 @@ private PreLoginHandshakeStatus ConsumePreLoginHandshake(
{
// Validate Certificate if Trust Server Certificate=false and Encryption forced (EncryptionOptions.ON) from Server.
bool shouldValidateServerCert = (_encryptionOption == EncryptionOptions.ON && !trustServerCert) ||
- (_connHandler._accessTokenInBytes != null && !trustServerCert);
+ ((_connHandler._accessTokenInBytes != null || _connHandler._accessTokenCallback != null)
+ && !trustServerCert);
uint info = (shouldValidateServerCert ? TdsEnums.SNI_SSL_VALIDATE_CERTIFICATE : 0)
| (is2005OrLater ? TdsEnums.SNI_SSL_USE_SCHANNEL_CACHE : 0);
@@ -1114,7 +1115,7 @@ private PreLoginHandshakeStatus ConsumePreLoginHandshake(
// Or AccessToken is not null, mean token based authentication is used.
if ((_connHandler.ConnectionOptions != null
&& _connHandler.ConnectionOptions.Authentication != SqlAuthenticationMethod.NotSpecified)
- || _connHandler._accessTokenInBytes != null)
+ || _connHandler._accessTokenInBytes != null || _connHandler._accessTokenCallback != null)
{
fedAuthRequired = payload[payloadOffset] == 0x01 ? true : false;
}
@@ -7925,7 +7926,14 @@ internal int WriteFedAuthFeatureRequest(FederatedAuthenticationFeatureExtensionD
workflow = TdsEnums.MSALWORKFLOW_ACTIVEDIRECTORYDEFAULT;
break;
default:
- Debug.Assert(false, "Unrecognized Authentication type for fedauth MSAL request");
+ if (_connHandler._accessTokenCallback != null)
+ {
+ workflow = TdsEnums.MSALWORKFLOW_ACTIVEDIRECTORYTOKENCREDENTIAL;
+ }
+ else
+ {
+ Debug.Assert(false, "Unrecognized Authentication type for fedauth MSAL request");
+ }
break;
}
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.Designer.cs b/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.Designer.cs
index f6fbcf8306..cf900c4553 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.Designer.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.Designer.cs
@@ -429,6 +429,15 @@ internal static string ADP_InvalidMixedUsageOfAccessTokenAndIntegratedSecurity {
}
}
+ ///
+ /// Looks up a localized string similar to Cannot set the AccessToken property if the AccessTokenCallback has been set..
+ ///
+ internal static string ADP_InvalidMixedUsageOfAccessTokenAndTokenCallback {
+ get {
+ return ResourceManager.GetString("ADP_InvalidMixedUsageOfAccessTokenAndTokenCallback", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to Cannot set the AccessToken property if 'UserID', 'UID', 'Password', or 'PWD' has been specified in connection string..
///
@@ -438,6 +447,24 @@ internal static string ADP_InvalidMixedUsageOfAccessTokenAndUserIDPassword {
}
}
+ ///
+ /// Looks up a localized string similar to Cannot set the AccessTokenCallback property if the 'Integrated Security' connection string keyword has been set to 'true' or 'SSPI'..
+ ///
+ internal static string ADP_InvalidMixedUsageOfAccessTokenCallbackAndIntegratedSecurity {
+ get {
+ return ResourceManager.GetString("ADP_InvalidMixedUsageOfAccessTokenCallbackAndIntegratedSecurity", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Cannot set the AccessTokenCallback property if 'Authentication=Active Directory Default' has been specified in the connection string..
+ ///
+ internal static string ADP_InvalidMixedUsageOfAuthenticationAndTokenCallback {
+ get {
+ return ResourceManager.GetString("ADP_InvalidMixedUsageOfAuthenticationAndTokenCallback", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to Cannot set the Credential property if the AccessToken property is already set..
///
@@ -950,7 +977,7 @@ internal static string Data_InvalidOffsetLength {
return ResourceManager.GetString("Data_InvalidOffsetLength", resourceCulture);
}
}
-
+
///
/// Looks up a localized string similar to Internal error occurred when retrying the download of the HGS root certificate after the initial request failed. Contact Customer Support Services..
///
@@ -1915,7 +1942,7 @@ internal static string SNI_ERROR_9 {
}
///
- /// Looks up a localized string similar to Incorrect physicalConnection type.
+ /// Looks up a localized string similar to Incorrect physicalConnection type..
///
internal static string SNI_IncorrectPhysicalConnectionType {
get {
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.resx b/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.resx
index 45dd4459d9..7f52b2556f 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.resx
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.resx
@@ -1941,4 +1941,13 @@
Socket did not throw expected '{0}' with error code '{1}'.
+
+ Cannot set the AccessToken property if the AccessTokenCallback has been set.
+
+
+ Cannot set the AccessTokenCallback property if the 'Integrated Security' connection string keyword has been set to 'true' or 'SSPI'.
+
+
+ Cannot set the AccessTokenCallback property if 'Authentication=Active Directory Default' has been specified in the connection string.
+
diff --git a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs
index ca229487be..4663cc151f 100644
--- a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs
+++ b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs
@@ -4,6 +4,7 @@
// NOTE: The current Microsoft.VSDesigner editor attributes are implemented for System.Data.SqlClient, and are not publicly available.
// New attributes that are designed to work with Microsoft.Data.SqlClient and are publicly documented should be included in future.
+
[assembly: System.CLSCompliant(true)]
[assembly: System.Resources.NeutralResourcesLanguageAttribute("en-US")]
namespace Microsoft.Data
@@ -767,6 +768,8 @@ public SqlConnection(string connectionString, Microsoft.Data.SqlClient.SqlCreden
[System.ComponentModel.BrowsableAttribute(false)]
[System.ComponentModel.DesignerSerializationVisibilityAttribute(0)]
public string AccessToken { get { throw null; } set { } }
+ ///
+ public System.Func> AccessTokenCallback { get { throw null; } set { } }
///
[System.ComponentModel.DesignerSerializationVisibilityAttribute(0)]
public System.Guid ClientConnectionId { get { throw null; } }
diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs
index f9c02e8efc..23893afcf6 100644
--- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs
+++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs
@@ -71,6 +71,8 @@ private static Dictionary s_systemC
/// Instance-level list of custom key store providers. It can be set more than once by the user.
private IReadOnlyDictionary _customColumnEncryptionKeyStoreProviders;
+ private Func> _accessTokenCallback;
+
internal bool HasColumnEncryptionKeyStoreProvidersRegistered =>
_customColumnEncryptionKeyStoreProviders is not null && _customColumnEncryptionKeyStoreProviders.Count > 0;
@@ -739,7 +741,30 @@ public string AccessToken
_accessToken = value;
// Need to call ConnectionString_Set to do proper pool group check
- ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, _accessToken, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo));
+ ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, _accessToken, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo, null));
+ }
+ }
+
+ ///
+ public Func> AccessTokenCallback
+ {
+ get { return _accessTokenCallback; }
+ set
+ {
+ // If a connection is connecting or is ever opened, AccessToken callback cannot be set
+ if (!InnerConnection.AllowSetConnectionString)
+ {
+ throw ADP.OpenConnectionPropertySet(nameof(AccessTokenCallback), InnerConnection.State);
+ }
+
+ if (value != null)
+ {
+ // Check if the usage of AccessToken has any conflict with the keys used in connection string and credential
+ CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback((SqlConnectionString)ConnectionOptions);
+ }
+
+ ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, null, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo, value));
+ _accessTokenCallback = value;
}
}
@@ -776,7 +801,7 @@ override public string ConnectionString
}
set
{
- if (_credential != null || _accessToken != null)
+ if (_credential != null || _accessToken != null || _accessTokenCallback != null)
{
SqlConnectionString connectionOptions = new SqlConnectionString(value);
if (_credential != null)
@@ -812,12 +837,18 @@ override public string ConnectionString
CheckAndThrowOnInvalidCombinationOfConnectionStringAndSqlCredential(connectionOptions);
}
- else if (_accessToken != null)
+
+ if (_accessToken != null)
{
CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessToken(connectionOptions);
}
+
+ if (_accessTokenCallback != null)
+ {
+ CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback(connectionOptions);
+ }
}
- ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo));
+ ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo, _accessTokenCallback));
_connectionString = value; // Change _connectionString value only after value is validated
CacheConnectionStringProperties();
}
@@ -1154,17 +1185,17 @@ public SqlCredential Credential
}
CheckAndThrowOnInvalidCombinationOfConnectionStringAndSqlCredential(connectionOptions);
+
if (_accessToken != null)
{
throw ADP.InvalidMixedUsageOfCredentialAndAccessToken();
}
-
}
_credential = value;
// Need to call ConnectionString_Set to do proper pool group check
- ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, _accessToken, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo));
+ ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, _accessToken, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo, _accessTokenCallback));
}
}
@@ -1221,6 +1252,33 @@ private void CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessToken(S
{
throw ADP.InvalidMixedUsageOfAccessTokenAndCredential();
}
+
+ if (_accessTokenCallback != null)
+ {
+ throw ADP.InvalidMixedUsageOfAccessTokenAndTokenCallback();
+ }
+ }
+
+ // CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback: check if the usage of AccessTokenCallback has any conflict
+ // with the keys used in connection string and credential
+ // If there is any conflict, it throws InvalidOperationException
+ // This is to be used setter of ConnectionString and AccessTokenCallback properties
+ private void CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback(SqlConnectionString connectionOptions)
+ {
+ if (UsesIntegratedSecurity(connectionOptions))
+ {
+ throw ADP.InvalidMixedUsageOfAccessTokenCallbackAndIntegratedSecurity();
+ }
+
+ if (UsesAuthentication(connectionOptions))
+ {
+ throw ADP.InvalidMixedUsageOfAccessTokenCallbackAndAuthentication();
+ }
+
+ if (_accessToken != null)
+ {
+ throw ADP.InvalidMixedUsageOfAccessTokenAndTokenCallback();
+ }
}
//
@@ -2699,7 +2757,7 @@ public static void ChangePassword(string connectionString, string newPassword)
throw ADP.InvalidArgumentLength("newPassword", TdsEnums.MAXLEN_NEWPASSWORD);
}
- SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, serverCertificateValidationCallback: null, clientCertificateRetrievalCallback: null, originalNetworkAddressInfo: null);
+ SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, serverCertificateValidationCallback: null, clientCertificateRetrievalCallback: null, originalNetworkAddressInfo: null, accessTokenCallback: null);
SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key);
if (connectionOptions.IntegratedSecurity || connectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
@@ -2755,7 +2813,7 @@ public static void ChangePassword(string connectionString, SqlCredential credent
throw ADP.InvalidArgumentLength("newSecurePassword", TdsEnums.MAXLEN_NEWPASSWORD);
}
- SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, serverCertificateValidationCallback: null, clientCertificateRetrievalCallback: null, originalNetworkAddressInfo: null);
+ SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, serverCertificateValidationCallback: null, clientCertificateRetrievalCallback: null, originalNetworkAddressInfo: null, accessTokenCallback: null);
SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key);
@@ -2800,7 +2858,7 @@ private static void ChangePassword(string connectionString, SqlConnectionString
throw SQL.ChangePasswordRequires2005();
}
}
- SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, serverCertificateValidationCallback: null, clientCertificateRetrievalCallback: null, originalNetworkAddressInfo: null);
+ SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, serverCertificateValidationCallback: null, clientCertificateRetrievalCallback: null, originalNetworkAddressInfo: null, accessTokenCallback: null);
SqlConnectionFactory.SingletonInstance.ClearPool(key);
}
diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs
index e2051784a3..aca7bfc151 100644
--- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs
+++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs
@@ -142,7 +142,7 @@ override protected DbConnectionInternal CreateConnection(DbConnectionOptions opt
opt = new SqlConnectionString(opt, instanceName, false /* user instance=false */, null /* do not modify the Enlist value */);
poolGroupProviderInfo = null; // null so we do not pass to constructor below...
}
- result = new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, key.ServerCertificateValidationCallback, key.ClientCertificateRetrievalCallback, pool, key.AccessToken, key.OriginalNetworkAddressInfo, applyTransientFaultHandling: applyTransientFaultHandling);
+ result = new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, key.ServerCertificateValidationCallback, key.ClientCertificateRetrievalCallback, pool, key.AccessToken, key.OriginalNetworkAddressInfo, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessTokenCallback);
}
return result;
}
diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs
index da623f874e..f6b81a533c 100644
--- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs
+++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs
@@ -136,6 +136,7 @@ sealed internal class SqlInternalConnectionTds : SqlInternalConnection, IDisposa
// The Federated Authentication returned by TryGetFedAuthTokenLocked or GetFedAuthToken.
SqlFedAuthToken _fedAuthToken = null;
internal byte[] _accessTokenInBytes;
+ internal readonly Func> _accessTokenCallback;
private readonly ActiveDirectoryAuthenticationTimeoutRetryHelper _activeDirectoryAuthTimeoutRetryHelper;
private readonly SqlAuthenticationProviderManager _sqlAuthenticationProviderManager;
@@ -432,7 +433,9 @@ internal SqlInternalConnectionTds(
DbConnectionPool pool = null,
string accessToken = null,
SqlClientOriginalNetworkAddressInfo originalNetworkAddressInfo = null,
- bool applyTransientFaultHandling = false) : base(connectionOptions)
+ bool applyTransientFaultHandling = false,
+ Func> accessTokenCallback = null) : base(connectionOptions)
{
#if DEBUG
@@ -487,6 +490,8 @@ internal SqlInternalConnectionTds(
_accessTokenInBytes = System.Text.Encoding.Unicode.GetBytes(accessToken);
}
+ _accessTokenCallback = accessTokenCallback;
+
_activeDirectoryAuthTimeoutRetryHelper = new ActiveDirectoryAuthenticationTimeoutRetryHelper();
_sqlAuthenticationProviderManager = SqlAuthenticationProviderManager.Instance;
@@ -1591,7 +1596,8 @@ private void Login(ServerInfo server, TimeoutTimer timeout, string newPassword,
|| ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryMSI
|| ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryDefault
// Since AD Integrated may be acting like Windows integrated, additionally check _fedAuthRequired
- || (ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryIntegrated && _fedAuthRequired))
+ || (ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryIntegrated && _fedAuthRequired)
+ || _accessTokenCallback != null)
{
requestedFeatures |= TdsEnums.FeatureExtension.FedAuth;
_federatedAuthenticationInfoRequested = true;
@@ -2578,6 +2584,7 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
{
Debug.Assert((ConnectionOptions._hasUserIdKeyword && ConnectionOptions._hasPasswordKeyword)
|| _credential != null
+ || _accessTokenCallback != null
|| ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryInteractive
|| ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity
|| ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryMSI
@@ -2772,13 +2779,13 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
int numberOfAttempts = 0;
// Object that will be returned to the caller, containing all required data about the token.
- SqlFedAuthToken fedAuthToken = new SqlFedAuthToken();
+ _fedAuthToken = new SqlFedAuthToken();
// Username to use in error messages.
string username = null;
var authProvider = _sqlAuthenticationProviderManager.GetProvider(ConnectionOptions.Authentication);
- if (authProvider == null)
+ if (authProvider == null && _accessTokenCallback == null)
throw SQL.CannotFindAuthProvider(ConnectionOptions.Authentication.ToString());
// retry getting access token once if MsalException.error_code is unknown_error.
@@ -2802,14 +2809,14 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
username = TdsEnums.NTAUTHORITYANONYMOUSLOGON;
if (_activeDirectoryAuthTimeoutRetryHelper.State == ActiveDirectoryAuthenticationTimeoutRetryState.Retrying)
{
- fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken;
+ _fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken;
}
else
{
// We use Task.Run here in all places to execute task synchronously in the same context.
// Fixes block-over-async deadlock possibilities https://github.com/dotnet/SqlClient/issues/1209
- fedAuthToken = Task.Run(async () => await authProvider.AcquireTokenAsync(authParamsBuilder)).GetAwaiter().GetResult().ToSqlFedAuthToken();
- _activeDirectoryAuthTimeoutRetryHelper.CachedToken = fedAuthToken;
+ _fedAuthToken = Task.Run(async () => await authProvider.AcquireTokenAsync(authParamsBuilder)).GetAwaiter().GetResult().ToSqlFedAuthToken();
+ _activeDirectoryAuthTimeoutRetryHelper.CachedToken = _fedAuthToken;
}
break;
case SqlAuthenticationMethod.ActiveDirectoryInteractive:
@@ -2819,20 +2826,20 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
case SqlAuthenticationMethod.ActiveDirectoryDefault:
if (_activeDirectoryAuthTimeoutRetryHelper.State == ActiveDirectoryAuthenticationTimeoutRetryState.Retrying)
{
- fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken;
+ _fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken;
}
else
{
authParamsBuilder.WithUserId(ConnectionOptions.UserID);
- fedAuthToken = Task.Run(async () => await authProvider.AcquireTokenAsync(authParamsBuilder)).GetAwaiter().GetResult().ToSqlFedAuthToken();
- _activeDirectoryAuthTimeoutRetryHelper.CachedToken = fedAuthToken;
+ _fedAuthToken = Task.Run(async () => await authProvider.AcquireTokenAsync(authParamsBuilder)).GetAwaiter().GetResult().ToSqlFedAuthToken();
+ _activeDirectoryAuthTimeoutRetryHelper.CachedToken = _fedAuthToken;
}
break;
case SqlAuthenticationMethod.ActiveDirectoryPassword:
case SqlAuthenticationMethod.ActiveDirectoryServicePrincipal:
if (_activeDirectoryAuthTimeoutRetryHelper.State == ActiveDirectoryAuthenticationTimeoutRetryState.Retrying)
{
- fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken;
+ _fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken;
}
else
{
@@ -2840,22 +2847,53 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
{
username = _credential.UserId;
authParamsBuilder.WithUserId(username).WithPassword(_credential.Password);
- fedAuthToken = Task.Run(async () => await authProvider.AcquireTokenAsync(authParamsBuilder)).GetAwaiter().GetResult().ToSqlFedAuthToken();
+ _fedAuthToken = Task.Run(async () => await authProvider.AcquireTokenAsync(authParamsBuilder)).GetAwaiter().GetResult().ToSqlFedAuthToken();
}
else
{
username = ConnectionOptions.UserID;
authParamsBuilder.WithUserId(username).WithPassword(ConnectionOptions.Password);
- fedAuthToken = Task.Run(async () => await authProvider.AcquireTokenAsync(authParamsBuilder)).GetAwaiter().GetResult().ToSqlFedAuthToken();
+ _fedAuthToken = Task.Run(async () => await authProvider.AcquireTokenAsync(authParamsBuilder)).GetAwaiter().GetResult().ToSqlFedAuthToken();
}
- _activeDirectoryAuthTimeoutRetryHelper.CachedToken = fedAuthToken;
+ _activeDirectoryAuthTimeoutRetryHelper.CachedToken = _fedAuthToken;
}
break;
default:
- throw SQL.UnsupportedAuthenticationSpecified(ConnectionOptions.Authentication);
+ if (_accessTokenCallback == null)
+ {
+ throw SQL.UnsupportedAuthenticationSpecified(ConnectionOptions.Authentication);
+ }
+
+ if (_activeDirectoryAuthTimeoutRetryHelper.State == ActiveDirectoryAuthenticationTimeoutRetryState.Retrying)
+ {
+ _fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken;
+ }
+ else
+ {
+ if (_credential != null)
+ {
+ username = _credential.UserId;
+ authParamsBuilder.WithUserId(username).WithPassword(_credential.Password);
+ }
+ else
+ {
+ authParamsBuilder.WithUserId(ConnectionOptions.UserID);
+ authParamsBuilder.WithPassword(ConnectionOptions.Password);
+ }
+ SqlAuthenticationParameters parameters = authParamsBuilder;
+ CancellationTokenSource cts = new();
+ // Use Connection timeout value to cancel token acquire request after certain period of time.(int)
+ if (_timeout.MillisecondsRemaining < Int32.MaxValue)
+ {
+ cts.CancelAfter((int)_timeout.MillisecondsRemaining);
+ }
+ _fedAuthToken = Task.Run(async () => await _accessTokenCallback(parameters, cts.Token)).GetAwaiter().GetResult().ToSqlFedAuthToken();
+ _activeDirectoryAuthTimeoutRetryHelper.CachedToken = _fedAuthToken;
+ }
+ break;
}
- Debug.Assert(fedAuthToken.accessToken != null, "AccessToken should not be null.");
+ Debug.Assert(_fedAuthToken.accessToken != null, "AccessToken should not be null.");
#if DEBUG
if (_forceMsalRetry)
{
@@ -2923,17 +2961,17 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
}
}
- Debug.Assert(fedAuthToken != null, "fedAuthToken should not be null.");
- Debug.Assert(fedAuthToken.accessToken != null && fedAuthToken.accessToken.Length > 0, "fedAuthToken.accessToken should not be null or empty.");
+ Debug.Assert(_fedAuthToken != null, "fedAuthToken should not be null.");
+ Debug.Assert(_fedAuthToken.accessToken != null && _fedAuthToken.accessToken.Length > 0, "fedAuthToken.accessToken should not be null or empty.");
// Store the newly generated token in _newDbConnectionPoolAuthenticationContext, only if using pooling.
if (_dbConnectionPool != null)
{
- DateTime expirationTime = DateTime.FromFileTimeUtc(fedAuthToken.expirationFileTime);
- _newDbConnectionPoolAuthenticationContext = new DbConnectionPoolAuthenticationContext(fedAuthToken.accessToken, expirationTime);
+ DateTime expirationTime = DateTime.FromFileTimeUtc(_fedAuthToken.expirationFileTime);
+ _newDbConnectionPoolAuthenticationContext = new DbConnectionPoolAuthenticationContext(_fedAuthToken.accessToken, expirationTime);
}
SqlClientEventSource.Log.TryTraceEvent(" {0}, Finished generating federated authentication token.", ObjectID);
- return fedAuthToken;
+ return _fedAuthToken;
}
internal void OnFeatureExtAck(int featureId, byte[] data)
diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs
index c407e1c6e9..16b2431a8a 100644
--- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs
+++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs
@@ -1487,7 +1487,10 @@ private PreLoginHandshakeStatus ConsumePreLoginHandshake(
}
// Validate Certificate if Trust Server Certificate=false and Encryption forced (EncryptionOptions.ON) from Server.
- bool shouldValidateServerCert = (_encryptionOption == EncryptionOptions.ON && !trustServerCert) || ((authType != SqlAuthenticationMethod.NotSpecified || _connHandler._accessTokenInBytes != null) && !trustServerCert);
+ bool shouldValidateServerCert = (_encryptionOption == EncryptionOptions.ON && !trustServerCert) ||
+ ((authType != SqlAuthenticationMethod.NotSpecified || (_connHandler._accessTokenInBytes != null ||
+ _connHandler._accessTokenCallback != null))
+ && !trustServerCert);
UInt32 info = (shouldValidateServerCert ? TdsEnums.SNI_SSL_VALIDATE_CERTIFICATE : 0)
| (is2005OrLater && (_encryptionOption & EncryptionOptions.CLIENT_CERT) == 0 ? TdsEnums.SNI_SSL_USE_SCHANNEL_CACHE : 0);
@@ -1551,7 +1554,7 @@ private PreLoginHandshakeStatus ConsumePreLoginHandshake(
// Or AccessToken is not null, mean token based authentication is used.
if ((_connHandler.ConnectionOptions != null
&& _connHandler.ConnectionOptions.Authentication != SqlAuthenticationMethod.NotSpecified)
- || _connHandler._accessTokenInBytes != null)
+ || _connHandler._accessTokenInBytes != null || _connHandler._accessTokenCallback != null)
{
fedAuthRequired = payload[payloadOffset] == 0x01 ? true : false;
}
@@ -8744,7 +8747,14 @@ internal int WriteFedAuthFeatureRequest(FederatedAuthenticationFeatureExtensionD
workflow = TdsEnums.MSALWORKFLOW_ACTIVEDIRECTORYDEFAULT;
break;
default:
- Debug.Assert(false, "Unrecognized Authentication type for fedauth MSAL request");
+ if (_connHandler._accessTokenCallback != null)
+ {
+ workflow = TdsEnums.MSALWORKFLOW_ACTIVEDIRECTORYTOKENCREDENTIAL;
+ }
+ else
+ {
+ Debug.Assert(false, "Unrecognized Authentication type for fedauth MSAL request");
+ }
break;
}
diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.Designer.cs b/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.Designer.cs
index b30f54f8e5..5fd4a54382 100644
--- a/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.Designer.cs
+++ b/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.Designer.cs
@@ -906,6 +906,15 @@ internal static string ADP_InvalidMixedUsageOfAccessTokenAndIntegratedSecurity {
}
}
+ ///
+ /// Looks up a localized string similar to Cannot set the AccessToken property if the AccessTokenCallback has been set..
+ ///
+ internal static string ADP_InvalidMixedUsageOfAccessTokenAndTokenCallback {
+ get {
+ return ResourceManager.GetString("ADP_InvalidMixedUsageOfAccessTokenAndTokenCallback", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to Cannot set the AccessToken property if 'UserID', 'UID', 'Password', or 'PWD' has been specified in connection string..
///
@@ -915,6 +924,24 @@ internal static string ADP_InvalidMixedUsageOfAccessTokenAndUserIDPassword {
}
}
+ ///
+ /// Looks up a localized string similar to Cannot set the AccessTokenCallback property if the 'Integrated Security' connection string keyword has been set to 'true' or 'SSPI'..
+ ///
+ internal static string ADP_InvalidMixedUsageOfAccessTokenCallbackAndIntegratedSecurity {
+ get {
+ return ResourceManager.GetString("ADP_InvalidMixedUsageOfAccessTokenCallbackAndIntegratedSecurity", resourceCulture);
+ }
+ }
+
+ ///
+ /// Looks up a localized string similar to Cannot set the AccessTokenCallback property if 'Authentication=Active Directory Default' has been specified in the connection string..
+ ///
+ internal static string ADP_InvalidMixedUsageOfAuthenticationAndTokenCallback {
+ get {
+ return ResourceManager.GetString("ADP_InvalidMixedUsageOfAuthenticationAndTokenCallback", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to Cannot set the Credential property if the AccessToken property is already set..
///
@@ -5559,6 +5586,15 @@ internal static string DbConnectionString_ApplicationName {
}
}
+ ///
+ /// Looks up a localized string similar to When true, enables usage of the Asynchronous functionality in the .NET Framework Data Provider..
+ ///
+ internal static string DbConnectionString_AsynchronousProcessing {
+ get {
+ return ResourceManager.GetString("DbConnectionString_AsynchronousProcessing", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to The name of the primary file, including the full path name, of an attachable database..
///
@@ -8910,6 +8946,15 @@ internal static string SQL_ArgumentLengthMismatch {
}
}
+ ///
+ /// Looks up a localized string similar to This command requires an asynchronous connection. Set "Asynchronous Processing=true" in the connection string..
+ ///
+ internal static string SQL_AsyncConnectionRequired {
+ get {
+ return ResourceManager.GetString("SQL_AsyncConnectionRequired", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to The asynchronous operation has already completed..
///
@@ -10593,6 +10638,15 @@ internal static string SqlConnection_AccessToken {
}
}
+ ///
+ /// Looks up a localized string similar to State of connection, synchronous or asynchronous. 'Asynchronous Processing=x' in the connection string..
+ ///
+ internal static string SqlConnection_Asynchronous {
+ get {
+ return ResourceManager.GetString("SqlConnection_Asynchronous", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to A guid to represent the physical connection..
///
diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.resx b/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.resx
index 85627eae31..2b2374e490 100644
--- a/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.resx
+++ b/src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.resx
@@ -4635,4 +4635,13 @@
The '{0}' platform is not supported when targeting .NET Framework.
-
+
+ Cannot set the AccessToken property if the AccessTokenCallback has been set.
+
+
+ Cannot set the AccessTokenCallback property if the 'Integrated Security' connection string keyword has been set to 'true' or 'SSPI'.
+
+
+ Cannot set the AccessTokenCallback property if 'Authentication=Active Directory Default' has been specified in the connection string.
+
+
\ No newline at end of file
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.Windows.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.Windows.cs
index 9b9c0f3341..abeadf86b7 100644
--- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.Windows.cs
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.Windows.cs
@@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using System;
using System.Runtime.InteropServices;
using System.Runtime.Versioning;
using System.Security;
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs
index 8f48131f8d..08e46ed284 100644
--- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs
@@ -840,7 +840,8 @@ private static string ConnectionStateMsg(ConnectionState state)
{ // MDAC 82165, if the ConnectionState enum to msg the localization looks weird
return state switch
{
- (ConnectionState.Closed) or (ConnectionState.Connecting | ConnectionState.Broken) => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg_Closed),
+ (ConnectionState.Closed) => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg_Closed),
+ (ConnectionState.Connecting | ConnectionState.Broken) => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg_Closed),
(ConnectionState.Connecting) => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg_Connecting),
(ConnectionState.Open) => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg_Open),
(ConnectionState.Open | ConnectionState.Executing) => StringsHelper.GetString(Strings.ADP_ConnectionStateMsg_OpenExecuting),
@@ -1267,7 +1268,16 @@ static internal InvalidOperationException InvalidMixedUsageOfAccessTokenAndAuthe
static internal Exception InvalidMixedUsageOfCredentialAndAccessToken()
=> InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfCredentialAndAccessToken));
-#endregion
+
+ static internal Exception InvalidMixedUsageOfAccessTokenAndTokenCallback()
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfAccessTokenAndTokenCallback));
+
+ internal static Exception InvalidMixedUsageOfAccessTokenCallbackAndAuthentication()
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfAuthenticationAndTokenCallback));
+
+ internal static Exception InvalidMixedUsageOfAccessTokenCallbackAndIntegratedSecurity()
+ => InvalidOperation(StringsHelper.GetString(Strings.ADP_InvalidMixedUsageOfAccessTokenCallbackAndIntegratedSecurity));
+ #endregion
internal static bool IsEmpty(string str) => string.IsNullOrEmpty(str);
internal static readonly IntPtr s_ptrZero = IntPtr.Zero;
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionPoolKey.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionPoolKey.cs
index c35ce6f08e..1eed6a229d 100644
--- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionPoolKey.cs
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionPoolKey.cs
@@ -2,7 +2,10 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using System;
using System.Diagnostics;
+using System.Threading;
+using System.Threading.Tasks;
using Microsoft.Data.Common;
namespace Microsoft.Data.SqlClient
@@ -14,9 +17,11 @@ internal class SqlConnectionPoolKey : DbConnectionPoolKey
private int _hashValue;
private readonly SqlCredential _credential;
private readonly string _accessToken;
+ private Func> _accessTokenCallback;
internal SqlCredential Credential => _credential;
internal string AccessToken => _accessToken;
+ internal Func> AccessTokenCallback => _accessTokenCallback;
internal override string ConnectionString
{
@@ -48,11 +53,13 @@ internal SqlConnectionPoolKey(string connectionString,
string accessToken,
ServerCertificateValidationCallback serverCertificateValidationCallback,
ClientCertificateRetrievalCallback clientCertificateRetrievalCallback,
- SqlClientOriginalNetworkAddressInfo originalNetworkAddressInfo) : base(connectionString)
+ SqlClientOriginalNetworkAddressInfo originalNetworkAddressInfo,
+ Func> accessTokenCallback = null) : base(connectionString)
{
- Debug.Assert(_credential == null || _accessToken == null, "Credential and AccessToken can't have the value at the same time.");
+ Debug.Assert(_credential == null || _accessToken == null || accessTokenCallback == null, "Credential, AccessToken, and Callback can't have a value at the same time.");
_credential = credential;
_accessToken = accessToken;
+ _accessTokenCallback = accessTokenCallback;
_serverCertificateValidationCallback = serverCertificateValidationCallback;
_clientCertificateRetrievalCallback = clientCertificateRetrievalCallback;
_originalNetworkAddressInfo = originalNetworkAddressInfo;
@@ -61,11 +68,12 @@ internal SqlConnectionPoolKey(string connectionString,
#endregion
#else
#region NET Core
- internal SqlConnectionPoolKey(string connectionString, SqlCredential credential, string accessToken) : base(connectionString)
+ internal SqlConnectionPoolKey(string connectionString, SqlCredential credential, string accessToken, Func> accessTokenCallback) : base(connectionString)
{
- Debug.Assert(_credential == null || _accessToken == null, "Credential and AccessToken can't have the value at the same time.");
+ Debug.Assert(credential == null || accessToken == null || accessTokenCallback == null, "Credential, AccessToken, and Callback can't have a value at the same time.");
_credential = credential;
_accessToken = accessToken;
+ _accessTokenCallback = accessTokenCallback;
CalculateHashCode();
}
#endregion
@@ -75,6 +83,7 @@ private SqlConnectionPoolKey(SqlConnectionPoolKey key) : base(key)
{
_credential = key.Credential;
_accessToken = key.AccessToken;
+ _accessTokenCallback = key._accessTokenCallback;
#if NETFRAMEWORK
_serverCertificateValidationCallback = key._serverCertificateValidationCallback;
_clientCertificateRetrievalCallback = key._clientCertificateRetrievalCallback;
@@ -92,6 +101,7 @@ public override bool Equals(object obj)
return (obj is SqlConnectionPoolKey key
&& _credential == key._credential
&& ConnectionString == key.ConnectionString
+ && _accessTokenCallback == key._accessTokenCallback
&& string.CompareOrdinal(_accessToken, key._accessToken) == 0
#if NETFRAMEWORK
&& _serverCertificateValidationCallback == key._serverCertificateValidationCallback
@@ -124,6 +134,13 @@ private void CalculateHashCode()
_hashValue = _hashValue * 17 + _accessToken.GetHashCode();
}
}
+ else if (_accessTokenCallback != null)
+ {
+ unchecked
+ {
+ _hashValue = _hashValue * 17 + _accessTokenCallback.GetHashCode();
+ }
+ }
#if NETFRAMEWORK
if (_originalNetworkAddressInfo != null)
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs
index 81c2ed1570..8a8cb3772d 100644
--- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs
@@ -272,6 +272,7 @@ public enum FedAuthLibrary : byte
public const byte MSALWORKFLOW_ACTIVEDIRECTORYDEVICECODEFLOW = 0x03; // Using the Interactive byte as that is the closest we have
public const byte MSALWORKFLOW_ACTIVEDIRECTORYMANAGEDIDENTITY = 0x03; // Using the Interactive byte as that's supported for Identity based authentication
public const byte MSALWORKFLOW_ACTIVEDIRECTORYDEFAULT = 0x03; // Using the Interactive byte as that is the closest we have to non-password based authentication modes
+ public const byte MSALWORKFLOW_ACTIVEDIRECTORYTOKENCREDENTIAL = 0x03; // Using the Interactive byte as that is the closest we have to non-password based authentication modes
public enum ActiveDirectoryWorkflow : byte
{
diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs
index e61bdfe4ac..03e6e7b5c7 100644
--- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs
+++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs
@@ -7,6 +7,7 @@
using System.Data.Common;
using System.Reflection;
using System.Security;
+using System.Threading;
using System.Threading.Tasks;
using Microsoft.SqlServer.TDS.Servers;
using Xunit;
@@ -264,5 +265,86 @@ public void ConnectionTestValidCredentialCombination()
Assert.Equal(sqlCredential, conn.Credential);
}
+
+ [Fact]
+ public void ConnectionTestAccessTokenCallbackCombinations()
+ {
+ var cleartextCredsConnStr = "User=test;Password=test;";
+ var sspiConnStr = "Integrated Security=true;";
+ var authConnStr = "Authentication=ActiveDirectoryPassword";
+ var testPassword = new SecureString();
+ testPassword.MakeReadOnly();
+ var sqlCredential = new SqlCredential(string.Empty, testPassword);
+ Func> callback = (ctx, token) =>
+ Task.FromResult(new SqlAuthenticationToken("invalid", DateTimeOffset.MaxValue));
+
+ // Successes
+ using (var conn = new SqlConnection(cleartextCredsConnStr))
+ {
+ conn.AccessTokenCallback = callback;
+ conn.AccessTokenCallback = null;
+ }
+
+ using (var conn = new SqlConnection(string.Empty, sqlCredential))
+ {
+ conn.AccessTokenCallback = null;
+ conn.AccessTokenCallback = callback;
+ }
+
+ using (var conn = new SqlConnection()
+ {
+ AccessTokenCallback = callback
+ })
+ {
+ conn.Credential = sqlCredential;
+ }
+
+ using (var conn = new SqlConnection()
+ {
+ AccessTokenCallback = callback
+ })
+ {
+ conn.ConnectionString = cleartextCredsConnStr;
+ }
+
+ //Failures
+ using (var conn = new SqlConnection(sspiConnStr))
+ {
+ Assert.Throws(() =>
+ {
+ conn.AccessTokenCallback = callback;
+ });
+ }
+
+ using (var conn = new SqlConnection(authConnStr))
+ {
+ Assert.Throws(() =>
+ {
+ conn.AccessTokenCallback = callback;
+ });
+ }
+
+ using (var conn = new SqlConnection()
+ {
+ AccessTokenCallback = callback
+ })
+ {
+ Assert.Throws(() =>
+ {
+ conn.ConnectionString = sspiConnStr;
+ });
+ }
+
+ using (var conn = new SqlConnection()
+ {
+ AccessTokenCallback = callback
+ })
+ {
+ Assert.Throws(() =>
+ {
+ conn.ConnectionString = authConnStr;
+ });
+ }
+ }
}
}
diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs
index 13f607583d..598fd1c8d8 100644
--- a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs
+++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs
@@ -203,7 +203,7 @@ public static IEnumerable GetConnectionStrings(bool withEnclave)
private static string GenerateAccessToken(string authorityURL, string aADAuthUserID, string aADAuthPassword)
{
- return AcquireTokenAsync(authorityURL, aADAuthUserID, aADAuthPassword).Result;
+ return AcquireTokenAsync(authorityURL, aADAuthUserID, aADAuthPassword).GetAwaiter().GetResult();
}
private static Task AcquireTokenAsync(string authorityURL, string userID, string password) => Task.Run(() =>
diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/AADConnectionTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/AADConnectionTest.cs
index 70cae84d73..52e5705140 100644
--- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/AADConnectionTest.cs
+++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/AADConnectionTest.cs
@@ -7,6 +7,8 @@
using System.Security;
using System.Threading;
using System.Threading.Tasks;
+using Azure.Core;
+using Azure.Identity;
using Microsoft.Identity.Client;
using Xunit;
@@ -551,6 +553,77 @@ public static void ActiveDirectoryDefaultWithPasswordMustFail()
Assert.Contains(expectedMessage, e.Message);
}
+ [ConditionalFact(nameof(IsAADConnStringsSetup))]
+ public static void ActiveDirectoryDefaultWithAccessTokenCallbackMustFail()
+ {
+ // connection fails with expected error message.
+ string[] credKeys = { "Authentication", "User ID", "Password", "UID", "PWD" };
+ string connStrWithNoCred = DataTestUtility.RemoveKeysInConnStr(DataTestUtility.AADPasswordConnectionString, credKeys) +
+ "Authentication=ActiveDirectoryDefault";
+ InvalidOperationException e = Assert.Throws(() =>
+ {
+ using (SqlConnection conn = new SqlConnection(connStrWithNoCred))
+ {
+ conn.AccessTokenCallback = (ctx, token) =>
+ Task.FromResult(new SqlAuthenticationToken("my token", DateTimeOffset.MaxValue));
+ conn.Open();
+
+ Assert.NotEqual(System.Data.ConnectionState.Open, conn.State);
+ }
+ });
+
+ string expectedMessage = "Cannot set the AccessTokenCallback property if 'Authentication=Active Directory Default' has been specified in the connection string.";
+ Assert.Contains(expectedMessage, e.Message);
+ }
+
+ [ConditionalFact(nameof(IsAADConnStringsSetup))]
+ public static void AccessTokenCallbackMustOpenPassAndChangePropertyFail()
+ {
+ string[] credKeys = { "Authentication", "User ID", "Password", "UID", "PWD" };
+ string connStr = DataTestUtility.RemoveKeysInConnStr(DataTestUtility.AADPasswordConnectionString, credKeys);
+ var cred = new DefaultAzureCredential();
+ const string defaultScopeSuffix = "/.default";
+ using (SqlConnection conn = new SqlConnection(connStr))
+ {
+ conn.AccessTokenCallback = (ctx, cancellationToken) =>
+ {
+ string scope = ctx.Resource.EndsWith(defaultScopeSuffix) ? ctx.Resource : ctx.Resource + defaultScopeSuffix;
+ AccessToken token = cred.GetToken(new TokenRequestContext(new[] { scope }), cancellationToken);
+ return Task.FromResult(new SqlAuthenticationToken(token.Token, token.ExpiresOn));
+ };
+ conn.Open();
+ Assert.Equal(System.Data.ConnectionState.Open, conn.State);
+
+ InvalidOperationException ex = Assert.Throws(() => conn.AccessTokenCallback = null);
+ string expectedMessage = "Not allowed to change the 'AccessTokenCallback' property. The connection's current state is open.";
+ Assert.Contains(expectedMessage, ex.Message);
+ }
+ }
+
+ [ConditionalFact(nameof(IsAADConnStringsSetup))]
+ public static void AccessTokenCallbackReceivesUsernameAndPassword()
+ {
+ var userId = "someuser";
+ var pwd = "somepassword";
+ string[] credKeys = { "Authentication", "User ID", "Password", "UID", "PWD" };
+ string connStr = DataTestUtility.RemoveKeysInConnStr(DataTestUtility.AADPasswordConnectionString, credKeys) +
+ $"User ID={userId}; Password={pwd}";
+ var cred = new DefaultAzureCredential();
+ const string defaultScopeSuffix = "/.default";
+ using (SqlConnection conn = new SqlConnection(connStr))
+ {
+ conn.AccessTokenCallback = (parms, cancellationToken) =>
+ {
+ Assert.Equal(userId, parms.UserId);
+ Assert.Equal(pwd, parms.Password);
+ string scope = parms.Resource.EndsWith(defaultScopeSuffix) ? parms.Resource : parms.Resource + defaultScopeSuffix;
+ AccessToken token = cred.GetToken(new TokenRequestContext(new[] { scope }), cancellationToken);
+ return Task.FromResult(new SqlAuthenticationToken(token.Token, token.ExpiresOn));
+ };
+ conn.Open();
+ }
+ }
+
[ConditionalFact(nameof(IsAADConnStringsSetup))]
public static void ActiveDirectoryDefaultMustPass()
{