Skip to content

Commit

Permalink
support get token by resource
Browse files Browse the repository at this point in the history
  • Loading branch information
suwatch committed Oct 3, 2017
1 parent 3c8b5d4 commit 79a2cad
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 29 deletions.
59 changes: 46 additions & 13 deletions AADClient.Console/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ static int Main(string[] args)
var uri = EnsureAbsoluteUri(path, persistentAuthHelper);

var subscriptionId = GetTenantOrSubscription(uri);
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId).Result;
var resource = GetResource(uri);
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId, resource).Result;
return HttpInvoke(uri, cacheInfo, "get", Utils.GetDefaultVerbose(), null, headers).Result;
}
else if (String.Equals(verb, "get-tenant", StringComparison.OrdinalIgnoreCase))
Expand All @@ -103,7 +104,8 @@ static int Main(string[] args)
var uri = EnsureAbsoluteUri(path, persistentAuthHelper);

var subscriptionId = GetTenantOrSubscription(uri);
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId).Result;
var resource = GetResource(uri);
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId, resource).Result;
return HttpInvoke(uri, cacheInfo, "get", Utils.GetDefaultVerbose(), null, headers).Result;
}
else if (String.Equals(verb, "get-apps", StringComparison.OrdinalIgnoreCase))
Expand All @@ -116,7 +118,8 @@ static int Main(string[] args)
var uri = EnsureAbsoluteUri(path, persistentAuthHelper);

var subscriptionId = GetTenantOrSubscription(uri);
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId).Result;
var resource = GetResource(uri);
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId, resource).Result;
return HttpInvoke(uri, cacheInfo, "get", Utils.GetDefaultVerbose(), null, headers).Result;
}
// https://azure.microsoft.com/en-us/documentation/articles/resource-group-authenticate-service-principal/
Expand All @@ -137,7 +140,8 @@ static int Main(string[] args)
var uri = EnsureAbsoluteUri(path, persistentAuthHelper);

var subscriptionId = GetTenantOrSubscription(uri);
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId).Result;
var resource = GetResource(uri);
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId, resource).Result;
return HttpInvoke(uri, cacheInfo, "get", Utils.GetDefaultVerbose(), null, headers).Result;
}
else if (String.Equals(verb, "get-app", StringComparison.OrdinalIgnoreCase))
Expand All @@ -154,9 +158,9 @@ static int Main(string[] args)
: String.Format("/{0}/applications?$filter=displayName eq '{1}'&api-version=1.6", tenant, app);

var uri = EnsureAbsoluteUri(path, persistentAuthHelper);

var resource = GetResource(uri);
var subscriptionId = GetTenantOrSubscription(uri);
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId).Result;
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId, resource).Result;
return HttpInvoke(uri, cacheInfo, "get", Utils.GetDefaultVerbose(), null, headers).Result;
}
// https://msdn.microsoft.com/library/azure/ad/graph/api/entity-and-complex-type-reference#serviceprincipalentity
Expand All @@ -173,7 +177,8 @@ static int Main(string[] args)
var uri = EnsureAbsoluteUri(path, persistentAuthHelper);

var subscriptionId = GetTenantOrSubscription(uri);
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId).Result;
var resource = GetResource(uri);
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId, resource).Result;
return HttpInvoke(uri, cacheInfo, "get", Utils.GetDefaultVerbose(), null, headers).Result;
}
else if (String.Equals(verb, "add-cred", StringComparison.OrdinalIgnoreCase))
Expand Down Expand Up @@ -223,7 +228,8 @@ static int Main(string[] args)
var uri = EnsureAbsoluteUri(path, persistentAuthHelper);

var subscriptionId = GetTenantOrSubscription(uri);
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId).Result;
var resource = GetResource(uri);
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId, resource).Result;

return HttpInvoke(uri, cacheInfo, "patch", Utils.GetDefaultVerbose(), content, headers).Result;
}
Expand All @@ -245,7 +251,8 @@ static int Main(string[] args)
var uri = EnsureAbsoluteUri(path, persistentAuthHelper);

var subscriptionId = GetTenantOrSubscription(uri);
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId).Result;
var resource = GetResource(uri);
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId, resource).Result;

return HttpInvoke(uri, cacheInfo, "patch", Utils.GetDefaultVerbose(), content, headers).Result;
}
Expand All @@ -259,7 +266,8 @@ static int Main(string[] args)
var uri = EnsureAbsoluteUri(path, persistentAuthHelper);

var subscriptionId = GetTenantOrSubscription(uri);
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId).Result;
var resource = GetResource(uri);
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId, resource).Result;
return HttpInvoke(uri, cacheInfo, "get", Utils.GetDefaultVerbose(), null, headers).Result;
}
else if (String.Equals(verb, "get-user", StringComparison.OrdinalIgnoreCase))
Expand All @@ -277,7 +285,8 @@ static int Main(string[] args)
var uri = EnsureAbsoluteUri(path, persistentAuthHelper);

var subscriptionId = GetTenantOrSubscription(uri);
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId).Result;
var resource = GetResource(uri);
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId, resource).Result;
return HttpInvoke(uri, cacheInfo, "get", Utils.GetDefaultVerbose(), null, headers).Result;
}
else if (String.Equals(verb, "get-groups", StringComparison.OrdinalIgnoreCase))
Expand All @@ -292,7 +301,8 @@ static int Main(string[] args)
var uri = EnsureAbsoluteUri(path, persistentAuthHelper);

var subscriptionId = GetTenantOrSubscription(uri);
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId).Result;
var resource = GetResource(uri);
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId, resource).Result;
var content = new StringContent("{\"securityEnabledOnly\": false}", Encoding.UTF8, "application/json");
return HttpInvoke(uri, cacheInfo, "post", Utils.GetDefaultVerbose(), content, headers).Result;
}
Expand Down Expand Up @@ -323,7 +333,8 @@ static async Task<JObject> GetAppObject(PersistentAuthHelper persistentAuthHelpe
var uri = EnsureAbsoluteUri(path, persistentAuthHelper);

var subscriptionId = GetTenantOrSubscription(uri);
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId).Result;
var resource = GetResource(uri);
TokenCacheInfo cacheInfo = persistentAuthHelper.GetToken(subscriptionId, resource).Result;

var json = await Utils.HttpGet(uri, cacheInfo);
var apps = json.Value<JArray>("value");
Expand Down Expand Up @@ -752,5 +763,27 @@ static AzureEnvironments GetAzureEnvironments(Uri uri, PersistentAuthHelper pers

return AzureEnvironments.Prod;
}

static string GetResource(Uri uri, AzureEnvironments env = AzureEnvironments.Prod)
{
try
{
if (Utils.IsGraphApi(uri))
{
return Constants.AADGraphUrls[(int)env];
}

if (Utils.IsKeyVault(uri))
{
return Constants.KeyVaultResources[(int)env];
}

return Constants.CSMResources[(int)env];
}
catch (Exception ex)
{
throw new InvalidOperationException(String.Format("Invalid url {0}!", uri), ex);
}
}
}
}
114 changes: 102 additions & 12 deletions ARMClient.Authentication/AADAuthentication/BaseAuthHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Security.Cryptography.X509Certificates;
using System.Text;
Expand Down Expand Up @@ -128,6 +129,31 @@ public async Task AzLogin()
this.TenantStorage.SaveCache(tenantCache);
}

public async Task<TokenCacheInfo> GetTokenByResource(string resource)
{
var cacheInfo = await GetRecentToken(resource);
if (cacheInfo != null)
{
return cacheInfo;
}

cacheInfo = await GetToken(null, null);
var tokenCache = TokenStorage.GetCache();
TokenCacheInfo found;
if (tokenCache.TryGetValue(cacheInfo.TenantId, resource, out found))
{
cacheInfo = found;
}
else
{
cacheInfo = await GetAuthorizationResult(tokenCache, tenantId: cacheInfo.TenantId, user: cacheInfo.DisplayableId, resource: resource);
this.TokenStorage.SaveCache(tokenCache);
}

this.TokenStorage.SaveRecentToken(cacheInfo, resource);
return cacheInfo;
}

private AzAccessToken[] GetAzLoginTokens()
{
var azCmd = Environment.ExpandEnvironmentVariables(@"%ProgramFiles(x86)%\Microsoft SDKs\Azure\CLI2\wbin\az.cmd");
Expand Down Expand Up @@ -220,11 +246,11 @@ private AzAccessToken[] GetAzLoginTokens()
}
}

public async Task<TokenCacheInfo> GetToken(string id)
public async Task<TokenCacheInfo> GetToken(string id, string resource)
{
try
{
return await GetTokenInternal(id);
return await GetTokenInternal(id, resource);
}
catch (AdalServiceException ex)
{
Expand All @@ -236,14 +262,14 @@ public async Task<TokenCacheInfo> GetToken(string id)

await AcquireTokens();

return await GetTokenInternal(id);
return await GetTokenInternal(id, resource);
}

private async Task<TokenCacheInfo> GetTokenInternal(string id)
private async Task<TokenCacheInfo> GetTokenInternal(string id, string resource)
{
if (String.IsNullOrEmpty(id))
{
return await GetRecentToken(Constants.CSMResources[(int)AzureEnvironments]);
return await GetRecentToken(resource ?? Constants.CSMResources[(int)AzureEnvironments]);
}

string tenantId = null;
Expand All @@ -266,12 +292,22 @@ private async Task<TokenCacheInfo> GetTokenInternal(string id)
}
}

// look up tenant by assuming it is subscription
if (String.IsNullOrEmpty(tenantId))
{
tenantId = await GetTenantIdFromSubscription(id, throwIfNotFound: true);
}

if (String.IsNullOrEmpty(tenantId))
{
return await GetRecentToken(Constants.CSMResources[(int)AzureEnvironments]);
}

var resource = id == tenantId ? Constants.AADGraphUrls[(int)AzureEnvironments] : Constants.CSMResources[(int)AzureEnvironments];
if (string.IsNullOrEmpty(resource))
{
resource = id == tenantId ? Constants.AADGraphUrls[(int)AzureEnvironments] : Constants.CSMResources[(int)AzureEnvironments];
}

var tokenCache = this.TokenStorage.GetCache();
TokenCacheInfo cacheInfo;
if (!tokenCache.TryGetValue(tenantId, resource, out cacheInfo))
Expand Down Expand Up @@ -300,6 +336,43 @@ private async Task<TokenCacheInfo> GetTokenInternal(string id)
return cacheInfo;
}

private async Task<string> GetTenantIdFromSubscription(string subscriptionId, bool throwIfNotFound = true)
{
using (var client = new HttpClient())
{
var serviceUrl = ARMClient.Authentication.Constants.CSMUrls[(int)AzureEnvironments];
string requestUri = String.Format("{0}/subscriptions/{1}?api-version=2014-04-01", serviceUrl.Trim('/'), subscriptionId);
using (var response = await client.GetAsync(requestUri))
{
if (response.StatusCode != HttpStatusCode.Unauthorized)
{
if (!throwIfNotFound && response.StatusCode == HttpStatusCode.NotFound)
{
return null;
}

throw new InvalidOperationException(String.Format("Expected Status {0} != {1} GET {2}", HttpStatusCode.Unauthorized, response.StatusCode, requestUri));
}

var header = response.Headers.WwwAuthenticate.SingleOrDefault();
if (header == null || String.IsNullOrEmpty(header.Parameter))
{
throw new InvalidOperationException(String.Format("Missing WWW-Authenticate response header GET {0}", requestUri));
}

// WWW-Authenticate: Bearer authorization_uri="https://login.windows.net/<tenantid>", error="invalid_token", error_description="The access token is missing or invalid."
var index = header.Parameter.IndexOf("authorization_uri=", StringComparison.OrdinalIgnoreCase);
if (index < 0)
{
throw new InvalidOperationException(String.Format("Invalid WWW-Authenticat response header {0} GET {1}", header.Parameter, requestUri));
}

var parts = header.Parameter.Substring(index).Split(new[] { '\"', '=' }, StringSplitOptions.RemoveEmptyEntries);
return new Uri(parts[1]).AbsolutePath.Trim('/');
}
}
}

public async Task<TokenCacheInfo> GetTokenBySpn(string tenantId, string appId, string appKey)
{
this.TokenStorage.ClearCache();
Expand Down Expand Up @@ -473,12 +546,29 @@ protected Task<TokenCacheInfo> GetAuthorizationResult(CustomTokenCache tokenCach
AuthenticationResult result = null;
if (!string.IsNullOrEmpty(user))
{
result = context.AcquireToken(
resource: resource,
clientId: Constants.AADClientId,
redirectUri: new Uri(Constants.AADRedirectUri),
promptBehavior: PromptBehavior.Never,
userId: new UserIdentifier(user, UserIdentifierType.OptionalDisplayableId));
try
{
result = context.AcquireToken(
resource: resource,
clientId: Constants.AADClientId,
redirectUri: new Uri(Constants.AADRedirectUri),
promptBehavior: PromptBehavior.Never,
userId: new UserIdentifier(user, UserIdentifierType.OptionalDisplayableId));
}
catch (AdalException adalEx)
{
if (adalEx.Message.IndexOf("user_interaction_required") < 0)
{
throw;
}

result = context.AcquireToken(
resource: resource,
clientId: Constants.AADClientId,
redirectUri: new Uri(Constants.AADRedirectUri),
promptBehavior: PromptBehavior.Auto,
userId: new UserIdentifier(user, UserIdentifierType.OptionalDisplayableId));
}
}
else
{
Expand Down
11 changes: 11 additions & 0 deletions ARMClient.Authentication/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ public static class Constants
"https://notsupport.com/"
};

public static string[] KeyVaultResources = new[]
{
"https://vault.azure.net",
"https://vault.azure.net",
"https://vault.azure.net",
"https://vault.azure.net",
"https://vault.azure.net",
"https://vault.azure.net",
"https://vault.azure.net"
};

public static string[] SCMSuffixes = new[]
{
".scm.chinacloudsites.cn",
Expand Down
2 changes: 1 addition & 1 deletion ARMClient.Authentication/IAuthHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public interface IAuthHelper
AzureEnvironments AzureEnvironments { get; set; }
Task AcquireTokens();
Task AzLogin();
Task<TokenCacheInfo> GetToken(string id);
Task<TokenCacheInfo> GetToken(string id, string resource);
Task<TokenCacheInfo> GetTokenBySpn(string tenantId, string appId, string appKey);
Task<TokenCacheInfo> GetTokenByUpn(string username, string password);
bool IsCacheValid();
Expand Down
6 changes: 6 additions & 0 deletions ARMClient.Authentication/Utilities/Utils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -288,5 +288,11 @@ public static bool IsCSM(Uri uri)
var host = uri.Host;
return Constants.CSMUrls.Any(url => url.IndexOf(host, StringComparison.OrdinalIgnoreCase) > 0);
}

public static bool IsKeyVault(Uri uri)
{
var host = uri.Host;
return host.EndsWith(".vault.azure.net", StringComparison.OrdinalIgnoreCase);
}
}
}
Loading

0 comments on commit 79a2cad

Please sign in to comment.