Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for validating the org_name claim [SDK-4414] #782

Merged
merged 5 commits into from
Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 36 additions & 7 deletions Auth0/ClaimValidators.swift
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ struct IDTokenAuthTimeValidator: JWTValidator {
}
}

struct IDTokenOrgIdValidator: JWTValidator {
struct IDTokenOrgIDValidator: JWTValidator {
enum ValidationError: Auth0Error {
case missingOrgId
case mismatchedOrgId(actual: String, expected: String)
Expand All @@ -263,16 +263,45 @@ struct IDTokenOrgIdValidator: JWTValidator {
}
}

private let expectedOrganization: String
private let expectedOrgID: String

init(organization: String) {
self.expectedOrganization = organization
init(orgID: String) {
self.expectedOrgID = orgID
}

func validate(_ jwt: JWT) -> Auth0Error? {
guard let actualOrganization = jwt.claim(name: "org_id").string else { return ValidationError.missingOrgId }
guard actualOrganization == expectedOrganization else {
return ValidationError.mismatchedOrgId(actual: actualOrganization, expected: expectedOrganization)
guard let actualOrgID = jwt.claim(name: "org_id").string else { return ValidationError.missingOrgId }
guard actualOrgID == expectedOrgID else {
return ValidationError.mismatchedOrgId(actual: actualOrgID, expected: expectedOrgID)
}
return nil
}
}

struct IDTokenOrgNameValidator: JWTValidator {
enum ValidationError: Auth0Error {
case missingOrgName
case mismatchedOrgName(actual: String, expected: String)

var debugDescription: String {
switch self {
case .missingOrgName: return "Organization Name (org_name) claim must be a string present in the ID token"
case .mismatchedOrgName(let actual, let expected):
return "Organization Name (org_name) claim value mismatch in the ID token; expected (\(expected)), found (\(actual))"
}
}
}

private let expectedOrgName: String

init(orgName: String) {
self.expectedOrgName = orgName
}

func validate(_ jwt: JWT) -> Auth0Error? {
guard let actualOrgName = jwt.claim(name: "org_name").string else { return ValidationError.missingOrgName }
guard actualOrgName.caseInsensitiveCompare(expectedOrgName) == .orderedSame else {
return ValidationError.mismatchedOrgName(actual: actualOrgName, expected: expectedOrgName)
}
return nil
}
Expand Down
6 changes: 5 additions & 1 deletion Auth0/IDTokenValidator.swift
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ func validate(idToken: String,
claimValidators.append(IDTokenAuthTimeValidator(leeway: context.leeway, maxAge: maxAge))
}
if let organization = context.organization {
claimValidators.append(IDTokenOrgIdValidator(organization: organization))
if organization.starts(with: "org_") {
claimValidators.append(IDTokenOrgIDValidator(orgID: organization))
} else {
claimValidators.append(IDTokenOrgNameValidator(orgName: organization))
}
}
let validator = IDTokenValidator(signatureValidator: signatureValidator ?? IDTokenSignatureValidator(context: context),
claimsValidator: claimsValidator ?? IDTokenClaimsValidator(validators: claimValidators),
Expand Down
81 changes: 65 additions & 16 deletions Auth0Tests/ClaimValidatorsSpec.swift
Original file line number Diff line number Diff line change
Expand Up @@ -414,47 +414,96 @@ class ClaimValidatorsSpec: IDTokenValidatorBaseSpec {

}

describe("organization validation") {
describe("organization id validation") {

var organizationValidator: IDTokenOrgIdValidator!
let expectedOrganization = "abc1234"
var orgIDValidator: IDTokenOrgIDValidator!
let expectedOrgID = "org_abc1234"

beforeEach {
organizationValidator = IDTokenOrgIdValidator(organization: expectedOrganization)
orgIDValidator = IDTokenOrgIDValidator(orgID: expectedOrgID)
}

context("missing org_id") {
it("should return nil if org_id is present") {
let jwt = generateJWT(organization: expectedOrganization)
let jwt = generateJWT(orgID: expectedOrgID)

expect(organizationValidator.validate(jwt)).to(beNil())
expect(orgIDValidator.validate(jwt)).to(beNil())
}

it("should return an error if org_id is missing") {
let jwt = generateJWT(organization: nil)
let expectedError = IDTokenOrgIdValidator.ValidationError.missingOrgId
let result = organizationValidator.validate(jwt)
let jwt = generateJWT(orgID: nil)
let expectedError = IDTokenOrgIDValidator.ValidationError.missingOrgId
let result = orgIDValidator.validate(jwt)

expect(result).to(matchError(expectedError))
expect(result?.localizedDescription).to(equal(expectedError.localizedDescription))
}
}

context("mismatched org_id") {
it("should return an error if org_id does not match the request organization") {
let organization = "xyz6789"
let jwt = generateJWT(organization: organization)
let expectedError = IDTokenOrgIdValidator.ValidationError.mismatchedOrgId(actual: organization,
expected: expectedOrganization)
let result = organizationValidator.validate(jwt)
it("should return an error if org_id does not match the request organization id") {
let orgID = "org_xyz6789"
let jwt = generateJWT(orgID: orgID)
let expectedError = IDTokenOrgIDValidator.ValidationError.mismatchedOrgId(actual: orgID,
expected: expectedOrgID)
let result = orgIDValidator.validate(jwt)

expect(result).to(matchError(expectedError))
expect(result?.localizedDescription).to(equal(expectedError.localizedDescription))
}
}

}


describe("organization name validation") {

var orgNameValidator: IDTokenOrgNameValidator!
let expectedOrgName = "abc1234"

beforeEach {
orgNameValidator = IDTokenOrgNameValidator(orgName: expectedOrgName)
}

context("missing org_name") {
it("should return nil if org_name is present") {
let jwt = generateJWT(orgName: expectedOrgName)

expect(orgNameValidator.validate(jwt)).to(beNil())
}

it("should return an error if org_name is missing") {
let jwt = generateJWT(orgName: nil)
let expectedError = IDTokenOrgNameValidator.ValidationError.missingOrgName
let result = orgNameValidator.validate(jwt)

expect(result).to(matchError(expectedError))
expect(result?.localizedDescription).to(equal(expectedError.localizedDescription))
}
}

context("mismatched org_name") {
it("should return an error if org_name does not match the request organization name") {
let orgName = "xyz6789"
let jwt = generateJWT(orgName: orgName)
let expectedError = IDTokenOrgNameValidator.ValidationError.mismatchedOrgName(actual: orgName,
expected: expectedOrgName)
let result = orgNameValidator.validate(jwt)

expect(result).to(matchError(expectedError))
expect(result?.localizedDescription).to(equal(expectedError.localizedDescription))
}
}

it("should perform a case insensitive compare") {
let orgName = "aBc1234"
let expectedOrgName = "AbC1234"
let jwt = generateJWT(orgName: orgName)
orgNameValidator = IDTokenOrgNameValidator(orgName: expectedOrgName)

expect(orgNameValidator.validate(jwt)).to(beNil())
}

}
}

}
20 changes: 13 additions & 7 deletions Auth0Tests/Generators.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ private func generateJWTPayload(iss: String?,
nonce: String?,
maxAge: Int?,
authTime: Date?,
organization: String?) -> String {
orgID: String?,
orgName: String?) -> String {
var bodyDict: [String: Any] = [:]

if let iss = iss {
Expand Down Expand Up @@ -84,10 +85,14 @@ private func generateJWTPayload(iss: String?,
bodyDict["nonce"] = nonce
}

if let organization = organization {
bodyDict["org_id"] = organization
if let orgID = orgID {
bodyDict["org_id"] = orgID
}


if let orgName = orgName {
bodyDict["org_name"] = orgName
}

return encodeJWTPart(from: bodyDict)
}

Expand All @@ -102,7 +107,8 @@ func generateJWT(alg: String = JWTAlgorithm.rs256.rawValue,
nonce: String? = "a1b2c3d4e5",
maxAge: Int? = nil,
authTime: Date? = nil,
organization: String? = nil,
orgID: String? = nil,
orgName: String? = nil,
signature: String? = nil) -> JWT {
let header = generateJWTHeader(alg: alg, kid: kid)
let body = generateJWTPayload(iss: iss,
Expand All @@ -114,7 +120,8 @@ func generateJWT(alg: String = JWTAlgorithm.rs256.rawValue,
nonce: nonce,
maxAge: maxAge,
authTime: authTime,
organization: organization)
orgID: orgID,
orgName: orgName)

let signableParts = "\(header).\(body)"
var signaturePart = ""
Expand All @@ -128,7 +135,6 @@ func generateJWT(alg: String = JWTAlgorithm.rs256.rawValue,
signaturePart = (data! as Data).a0_encodeBase64URLSafe()!
}


return try! decode(jwt: "\(signableParts).\(signaturePart)")
}

Expand Down
73 changes: 69 additions & 4 deletions Auth0Tests/IDTokenValidatorSpec.swift
Original file line number Diff line number Diff line change
Expand Up @@ -216,16 +216,16 @@ class IDTokenValidatorSpec: IDTokenValidatorBaseSpec {
}
}

it("should validate a token with an organization") {
let organization = "abc1234"
let jwt = generateJWT(aud: aud, azp: nil, nonce: nil, maxAge: nil, authTime: nil, organization: organization)
it("should validate a token with an organization ID") {
let orgID = "org_abc1234"
let jwt = generateJWT(aud: aud, azp: nil, nonce: nil, maxAge: nil, authTime: nil, orgID: orgID)
let context = IDTokenValidatorContext(issuer: validatorContext.issuer,
audience: aud[0],
jwksRequest: validatorContext.jwksRequest,
leeway: validatorContext.leeway,
maxAge: nil,
nonce: nil,
organization: organization)
organization: orgID)

await waitUntil { done in
validate(idToken: jwt.string,
Expand All @@ -236,6 +236,71 @@ class IDTokenValidatorSpec: IDTokenValidatorBaseSpec {
}
}
}

it("should validate a token with an organization name") {
let orgName = "abc1234"
let jwt = generateJWT(aud: aud, azp: nil, nonce: nil, maxAge: nil, authTime: nil, orgName: orgName)
let context = IDTokenValidatorContext(issuer: validatorContext.issuer,
audience: aud[0],
jwksRequest: validatorContext.jwksRequest,
leeway: validatorContext.leeway,
maxAge: nil,
nonce: nil,
organization: orgName)

await waitUntil { done in
validate(idToken: jwt.string,
with: context,
signatureValidator: mockSignatureValidator) { error in
expect(error).to(beNil())
done()
}
}
}

it("should expect an organization ID instead of an organization name") {
let orgID = "org_abc1234"
let jwt = generateJWT(aud: aud, azp: nil, nonce: nil, maxAge: nil, authTime: nil, orgName: orgID)
let context = IDTokenValidatorContext(issuer: validatorContext.issuer,
audience: aud[0],
jwksRequest: validatorContext.jwksRequest,
leeway: validatorContext.leeway,
maxAge: nil,
nonce: nil,
organization: orgID)
let expectedError = IDTokenOrgIDValidator.ValidationError.missingOrgId

await waitUntil { done in
validate(idToken: jwt.string,
with: context,
signatureValidator: mockSignatureValidator) { error in
expect(error).to(matchError(expectedError))
done()
}
}
}

it("should expect an organization name instead of an organization ID") {
let orgName = "abc1234"
let jwt = generateJWT(aud: aud, azp: nil, nonce: nil, maxAge: nil, authTime: nil, orgID: orgName)
let context = IDTokenValidatorContext(issuer: validatorContext.issuer,
audience: aud[0],
jwksRequest: validatorContext.jwksRequest,
leeway: validatorContext.leeway,
maxAge: nil,
nonce: nil,
organization: orgName)
let expectedError = IDTokenOrgNameValidator.ValidationError.missingOrgName

await waitUntil { done in
validate(idToken: jwt.string,
with: context,
signatureValidator: mockSignatureValidator) { error in
expect(error).to(matchError(expectedError))
done()
}
}
}
}

}
Expand Down
6 changes: 3 additions & 3 deletions EXAMPLES.md
Original file line number Diff line number Diff line change
Expand Up @@ -1212,7 +1212,7 @@ Auth0
```swift
Auth0
.webAuth()
.organization("YOUR_AUTH0_ORGANIZATION_ID")
.organization("YOUR_AUTH0_ORGANIZATION_NAME_OR_ID")
.start { result in
switch result {
case .success(let credentials):
Expand All @@ -1230,7 +1230,7 @@ Auth0
do {
let credentials = try await Auth0
.webAuth()
.organization("YOUR_AUTH0_ORGANIZATION_ID")
.organization("YOUR_AUTH0_ORGANIZATION_NAME_OR_ID")
.start()
print("Obtained credentials: \(credentials)")
} catch {
Expand All @@ -1245,7 +1245,7 @@ do {
```swift
Auth0
.webAuth()
.organization("YOUR_AUTH0_ORGANIZATION_ID")
.organization("YOUR_AUTH0_ORGANIZATION_NAME_OR_ID")
.start()
.sink(receiveCompletion: { completion in
if case .failure(let error) = completion {
Expand Down