Skip to content

Commit 8fad4a4

Browse files
authored
API | AccessTokenCallback support (#1260)
1 parent 2b31810 commit 8fad4a4

File tree

23 files changed

+653
-85
lines changed

23 files changed

+653
-85
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
using System;
2+
using System.Data;
3+
// <Snippet1>
4+
using Microsoft.Data.SqlClient;
5+
using Azure.Identity;
6+
7+
class Program
8+
{
9+
static void Main()
10+
{
11+
OpenSqlConnection();
12+
Console.ReadLine();
13+
}
14+
15+
private static void OpenSqlConnection()
16+
{
17+
string connectionString = GetConnectionString();
18+
using (SqlConnection connection = new SqlConnection("Data Source=contoso.database.windows.net;Initial Catalog=AdventureWorks;")
19+
{
20+
AccessTokenCallback = async (authParams, cancellationToken) =>
21+
{
22+
var cred = new DefaultAzureCredential();
23+
string scope = authParams.Resource.EndsWith(s_defaultScopeSuffix) ? authParams.Resource : authParams.Resource + s_defaultScopeSuffix;
24+
var token = await cred.GetTokenAsync(new TokenRequestContext(new[] { scope }), cancellationToken);
25+
return new SqlAuthenticationToken(token.Token, token.ExpiresOn);
26+
}
27+
})
28+
{
29+
connection.Open();
30+
Console.WriteLine("ServerVersion: {0}", connection.ServerVersion);
31+
Console.WriteLine("State: {0}", connection.State);
32+
}
33+
}
34+
}
35+
// </Snippet1>

doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,22 @@ using (SqlConnection connection = new SqlConnection(connectionString))
129129
<value>The access token for the connection.</value>
130130
<remarks>To be added.</remarks>
131131
</AccessToken>
132+
<AccessTokenCallback>
133+
<summary>Gets or sets the access token callback for the connection.</summary>
134+
<value>
135+
The Func that takes a <see cref="SqlAuthenticationParameters" /> and <see cref="System.Threading.CancellationToken" /> and returns a <see cref="SqlAuthenticationToken" />.</value>
136+
<remarks>
137+
<format type="text/markdown"><![CDATA[
138+
139+
## Examples
140+
The following example demonstrates how to define and set an <xref:Microsoft.Data.SqlClient.AccessTokenCallback>.
141+
142+
[!code-csharp[SqlConnection_AccessTokenCallback Example#1](~/../sqlclient/doc/samples/SqlConnection_AccessTokenCallback.cs#1)]
143+
144+
]]></format>
145+
</remarks>
146+
<exception cref="T:System.InvalidOperationException">The AccessTokenCallback is combined with other conflicting authentication configurations.</exception>
147+
</AccessTokenCallback>
132148
<BeginDbTransaction>
133149
<param name="isolationLevel">To be added.</param>
134150
<summary>To be added.</summary>

src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
// NOTE: The current Microsoft.VSDesigner editor attributes are implemented for System.Data.SqlClient, and are not publicly available.
66
// New attributes that are designed to work with Microsoft.Data.SqlClient and are publicly documented should be included in future.
7+
78
[assembly: System.CLSCompliant(true)]
89
namespace Microsoft.Data
910
{
@@ -839,6 +840,8 @@ public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(System.Collect
839840
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/ClientConnectionId/*'/>
840841
[System.ComponentModel.DesignerSerializationVisibilityAttribute(0)]
841842
public System.Guid ClientConnectionId { get { throw null; } }
843+
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/AccessTokenCallback/*' />
844+
public System.Func<SqlAuthenticationParameters, System.Threading.CancellationToken, System.Threading.Tasks.Task<SqlAuthenticationToken>> AccessTokenCallback { get { throw null; } set { } }
842845

843846
///
844847
/// for internal test only

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs

Lines changed: 69 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ private static readonly Dictionary<string, SqlColumnEncryptionKeyStoreProvider>
8787
/// Instance-level list of custom key store providers. It can be set more than once by the user.
8888
private IReadOnlyDictionary<string, SqlColumnEncryptionKeyStoreProvider> _customColumnEncryptionKeyStoreProviders;
8989

90+
private Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> _accessTokenCallback;
91+
9092
internal bool HasColumnEncryptionKeyStoreProvidersRegistered =>
9193
_customColumnEncryptionKeyStoreProviders is not null && _customColumnEncryptionKeyStoreProviders.Count > 0;
9294

@@ -272,7 +274,7 @@ internal static List<string> GetColumnEncryptionSystemKeyStoreProvidersNames()
272274
}
273275

274276
/// <summary>
275-
/// This function returns a list of the names of the custom providers currently registered. If the
277+
/// This function returns a list of the names of the custom providers currently registered. If the
276278
/// instance-level cache is not empty, that cache is used, else the global cache is used.
277279
/// </summary>
278280
/// <returns>Combined list of provider names</returns>
@@ -344,7 +346,7 @@ public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(IDictionary<st
344346
new(customProviders, StringComparer.OrdinalIgnoreCase);
345347

346348
// Set the dictionary to the ReadOnly dictionary.
347-
// This method can be called more than once. Re-registering a new collection will replace the
349+
// This method can be called more than once. Re-registering a new collection will replace the
348350
// old collection of providers.
349351
_customColumnEncryptionKeyStoreProviders = customColumnEncryptionKeyStoreProviders;
350352
}
@@ -584,7 +586,7 @@ public override string ConnectionString
584586
}
585587
set
586588
{
587-
if (_credential != null || _accessToken != null)
589+
if (_credential != null || _accessToken != null || _accessTokenCallback != null)
588590
{
589591
SqlConnectionString connectionOptions = new SqlConnectionString(value);
590592
if (_credential != null)
@@ -620,12 +622,18 @@ public override string ConnectionString
620622

621623
CheckAndThrowOnInvalidCombinationOfConnectionStringAndSqlCredential(connectionOptions);
622624
}
623-
else if (_accessToken != null)
625+
626+
if (_accessToken != null)
624627
{
625628
CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessToken(connectionOptions);
626629
}
630+
631+
if (_accessTokenCallback != null)
632+
{
633+
CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback(connectionOptions);
634+
}
627635
}
628-
ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken));
636+
ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _accessTokenCallback));
629637
_connectionString = value; // Change _connectionString value only after value is validated
630638
CacheConnectionStringProperties();
631639
}
@@ -685,11 +693,34 @@ public string AccessToken
685693
}
686694

687695
// Need to call ConnectionString_Set to do proper pool group check
688-
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: value));
696+
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: value, accessTokenCallback: null));
689697
_accessToken = value;
690698
}
691699
}
692700

701+
/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/AccessTokenCallback/*' />
702+
public Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> AccessTokenCallback
703+
{
704+
get { return _accessTokenCallback; }
705+
set
706+
{
707+
// If a connection is connecting or is ever opened, AccessToken callback cannot be set
708+
if (!InnerConnection.AllowSetConnectionString)
709+
{
710+
throw ADP.OpenConnectionPropertySet(nameof(AccessTokenCallback), InnerConnection.State);
711+
}
712+
713+
if (value != null)
714+
{
715+
// Check if the usage of AccessToken has any conflict with the keys used in connection string and credential
716+
CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback((SqlConnectionString)ConnectionOptions);
717+
}
718+
719+
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: null, accessTokenCallback: value));
720+
_accessTokenCallback = value;
721+
}
722+
}
723+
693724
/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/Database/*' />
694725
[ResDescription(StringsHelper.ResourceNames.SqlConnection_Database)]
695726
[ResCategory(StringsHelper.ResourceNames.SqlConnection_DataSource)]
@@ -970,6 +1001,7 @@ public SqlCredential Credential
9701001
}
9711002

9721003
CheckAndThrowOnInvalidCombinationOfConnectionStringAndSqlCredential(connectionOptions);
1004+
9731005
if (_accessToken != null)
9741006
{
9751007
throw ADP.InvalidMixedUsageOfCredentialAndAccessToken();
@@ -979,7 +1011,7 @@ public SqlCredential Credential
9791011
_credential = value;
9801012

9811013
// Need to call ConnectionString_Set to do proper pool group check
982-
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: _accessToken));
1014+
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: _accessToken, accessTokenCallback: _accessTokenCallback));
9831015
}
9841016
}
9851017

@@ -1026,6 +1058,33 @@ private void CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessToken(S
10261058
{
10271059
throw ADP.InvalidMixedUsageOfCredentialAndAccessToken();
10281060
}
1061+
1062+
if(_accessTokenCallback != null)
1063+
{
1064+
throw ADP.InvalidMixedUsageOfAccessTokenAndTokenCallback();
1065+
}
1066+
}
1067+
1068+
// CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback: check if the usage of AccessTokenCallback has any conflict
1069+
// with the keys used in connection string and credential
1070+
// If there is any conflict, it throws InvalidOperationException
1071+
// This is to be used setter of ConnectionString and AccessTokenCallback properties
1072+
private void CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback(SqlConnectionString connectionOptions)
1073+
{
1074+
if (UsesIntegratedSecurity(connectionOptions))
1075+
{
1076+
throw ADP.InvalidMixedUsageOfAccessTokenCallbackAndIntegratedSecurity();
1077+
}
1078+
1079+
if (UsesAuthentication(connectionOptions))
1080+
{
1081+
throw ADP.InvalidMixedUsageOfAccessTokenCallbackAndAuthentication();
1082+
}
1083+
1084+
if(_accessToken != null)
1085+
{
1086+
throw ADP.InvalidMixedUsageOfAccessTokenAndTokenCallback();
1087+
}
10291088
}
10301089

10311090
/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/DbProviderFactory/*' />
@@ -2128,7 +2187,7 @@ public static void ChangePassword(string connectionString, string newPassword)
21282187
throw ADP.InvalidArgumentLength(nameof(newPassword), TdsEnums.MAXLEN_NEWPASSWORD);
21292188
}
21302189

2131-
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null);
2190+
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, accessTokenCallback: null);
21322191

21332192
SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key);
21342193
if (connectionOptions.IntegratedSecurity)
@@ -2177,7 +2236,7 @@ public static void ChangePassword(string connectionString, SqlCredential credent
21772236
throw ADP.InvalidArgumentLength(nameof(newSecurePassword), TdsEnums.MAXLEN_NEWPASSWORD);
21782237
}
21792238

2180-
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null);
2239+
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null);
21812240

21822241
SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key);
21832242

@@ -2216,7 +2275,7 @@ private static void ChangePassword(string connectionString, SqlConnectionString
22162275
if (con != null)
22172276
con.Dispose();
22182277
}
2219-
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null);
2278+
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null);
22202279

22212280
SqlConnectionFactory.SingletonInstance.ClearPool(key);
22222281
}

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ override protected DbConnectionInternal CreateConnection(DbConnectionOptions opt
133133
opt = new SqlConnectionString(opt, instanceName, userInstance: false, setEnlistValue: null);
134134
poolGroupProviderInfo = null; // null so we do not pass to constructor below...
135135
}
136-
return new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool);
136+
return new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool, key.AccessTokenCallback);
137137
}
138138

139139
protected override DbConnectionOptions CreateConnectionOptions(string connectionString, DbConnectionOptions previous)

0 commit comments

Comments
 (0)