Skip to content

Commit 3697a66

Browse files
committed
Support oauth2 device grant flows in library (lispkit http oauth).
1 parent c26bef2 commit 3697a66

File tree

2 files changed

+298
-3
lines changed

2 files changed

+298
-3
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
//
2+
// OAuth2DeviceGrantLK.swift
3+
// LispKit
4+
//
5+
// Created by Matthias Zenger on 12/07/2024.
6+
// Copyright © 2024 ObjectHub. All rights reserved.
7+
//
8+
9+
import Foundation
10+
import OAuth2
11+
12+
/// https://www.ietf.org/rfc/rfc8628.html
13+
open class OAuth2DeviceGrantLK: OAuth2 {
14+
public private(set) var deviceCode: String? = nil
15+
public private(set) var pollingInterval: Double = 5.0
16+
17+
override open class var grantType: String {
18+
return "urn:ietf:params:oauth:grant-type:device_code"
19+
}
20+
21+
override open class var responseType: String? {
22+
return ""
23+
}
24+
25+
open func deviceAccessTokenRequest(with deviceCode: String) throws -> OAuth2AuthRequest {
26+
guard let clientId = clientConfig.clientId, !clientId.isEmpty else {
27+
throw OAuth2Error.noClientId
28+
}
29+
let req = OAuth2AuthRequest(url: (clientConfig.tokenURL ?? clientConfig.authorizeURL))
30+
req.params["device_code"] = deviceCode
31+
req.params["grant_type"] = type(of: self).grantType
32+
req.params["client_id"] = clientId
33+
return req
34+
}
35+
36+
open func deviceAuthorizationRequest(params: OAuth2StringDict? = nil) throws -> OAuth2AuthRequest {
37+
guard let clientId = clientConfig.clientId, !clientId.isEmpty else {
38+
throw OAuth2Error.noClientId
39+
}
40+
guard let url = clientConfig.deviceAuthorizeURL else {
41+
throw OAuth2Error.noDeviceCodeURL
42+
}
43+
let req = OAuth2AuthRequest(url: url)
44+
req.params["client_id"] = clientId
45+
if let scope = clientConfig.scope {
46+
req.params["scope"] = scope
47+
}
48+
req.add(params: params)
49+
return req
50+
}
51+
52+
open func parseDeviceAuthorizationResponse(data: Data) throws -> OAuth2JSON {
53+
let dict = try parseJSON(data)
54+
return try parseDeviceAuthorizationResponse(params: dict)
55+
}
56+
57+
public final func parseDeviceAuthorizationResponse(params: OAuth2JSON) throws -> OAuth2JSON {
58+
try assureNoErrorInResponse(params)
59+
return params
60+
}
61+
62+
/**
63+
Start the device authorization flow.
64+
65+
- parameter params: Optional key/value pairs to pass during authorize device request
66+
- parameter callback: The callback to call after the device authorization response has been received
67+
*/
68+
public func start(useNonTextualTransmission: Bool = false,
69+
params: OAuth2StringDict? = nil,
70+
queue: DispatchQueue? = .main,
71+
completion: @escaping (DeviceAuthCodes?, Error?) -> Void) {
72+
authorizeDevice(params: params) { result, error in
73+
guard let result else {
74+
if let error {
75+
self.logger?.warn("OAuth2", msg: "Unable to get device code: \(error)")
76+
}
77+
completion(nil, error)
78+
return
79+
}
80+
guard let deviceCode = result["device_code"] as? String,
81+
let userCode = result["user_code"] as? String,
82+
let verificationUri = result["verification_uri"] as? String,
83+
let verificationUrl = URL(string: verificationUri),
84+
let expiresIn = result["expires_in"] as? Int else {
85+
let error = OAuth2Error.generic("The response doesn't contain all required fields.")
86+
self.logger?.warn("OAuth2", msg: String(describing: error))
87+
completion(nil, error)
88+
return
89+
}
90+
var verificationUrlComplete: URL?
91+
if let verificationUriComplete = result["verification_uri_complete"] as? String {
92+
verificationUrlComplete = URL(string: verificationUriComplete)
93+
}
94+
if useNonTextualTransmission, let url = verificationUrlComplete {
95+
do {
96+
try self.authorizer.openAuthorizeURLInBrowser(url)
97+
} catch let error {
98+
completion(nil, error)
99+
}
100+
}
101+
self.deviceCode = deviceCode
102+
self.pollingInterval = result["interval"] as? TimeInterval ?? 5.0
103+
if let queue {
104+
self.getDeviceAccessToken(deviceCode: deviceCode, interval: self.pollingInterval, queue: queue) { params, error in
105+
if let params {
106+
self.didAuthorize(withParameters: params)
107+
}
108+
else if let error {
109+
self.didFail(with: error.asOAuth2Error)
110+
}
111+
}
112+
}
113+
let deviceAuthorization = DeviceAuthCodes(
114+
deviceCode: deviceCode,
115+
userCode: userCode,
116+
verificationUrl: verificationUrl,
117+
verificationUrlComplete: verificationUrlComplete,
118+
expiresIn: expiresIn,
119+
interval: self.pollingInterval)
120+
completion(deviceAuthorization, nil)
121+
}
122+
}
123+
124+
private func authorizeDevice(params: OAuth2StringDict?, completion: @escaping (OAuth2JSON?, Error?) -> Void) {
125+
do {
126+
let post = try deviceAuthorizationRequest(params: params).asURLRequest(for: self)
127+
logger?.debug("OAuth2", msg: "Obtaining device code from \(post.url!)")
128+
129+
perform(request: post) { response in
130+
do {
131+
let data = try response.responseData()
132+
let params = try self.parseDeviceAuthorizationResponse(data: data)
133+
completion(params, nil)
134+
}
135+
catch let error {
136+
completion(nil, error.asOAuth2Error)
137+
}
138+
}
139+
}
140+
catch let error {
141+
completion(nil, error.asOAuth2Error)
142+
}
143+
}
144+
145+
public func getDeviceAccessToken(deviceCode: String,
146+
interval: TimeInterval,
147+
queue: DispatchQueue = .main,
148+
completion: @escaping (OAuth2JSON?, Error?) -> Void) {
149+
do {
150+
let post = try deviceAccessTokenRequest(with: deviceCode).asURLRequest(for: self)
151+
logger?.debug("OAuth2", msg: "Obtaining access token for device with code \(deviceCode) from \(post.url!)")
152+
perform(request: post) { response in
153+
do {
154+
let data = try response.responseData()
155+
let params = try self.parseAccessTokenResponse(data: data)
156+
completion(params, nil)
157+
}
158+
catch let error {
159+
let oaerror = error.asOAuth2Error
160+
161+
if oaerror == .authorizationPending(nil) {
162+
self.logger?.debug("OAuth2", msg: "AuthorizationPending, repeating in \(interval) seconds.")
163+
queue.asyncAfter(deadline: .now() + interval) {
164+
self.getDeviceAccessToken(deviceCode: deviceCode,
165+
interval: interval,
166+
queue: queue,
167+
completion: completion)
168+
}
169+
} else if oaerror == .slowDown(nil) {
170+
let updatedInterval = interval + 5 // The 5 seconds increase is required by the RFC8628 standard (https://www.rfc-editor.org/rfc/rfc8628#section-3.5)
171+
self.logger?.debug("OAuth2", msg: "SlowDown, repeating in \(updatedInterval) seconds.")
172+
queue.asyncAfter(deadline: .now() + updatedInterval) {
173+
self.getDeviceAccessToken(deviceCode: deviceCode,
174+
interval: updatedInterval,
175+
queue: queue,
176+
completion: completion)
177+
}
178+
} else {
179+
completion(nil, oaerror)
180+
}
181+
}
182+
}
183+
}
184+
catch let error {
185+
completion(nil, error.asOAuth2Error)
186+
}
187+
}
188+
}
189+
190+
public struct DeviceAuthCodes {
191+
public let deviceCode: String
192+
public let userCode: String
193+
public let verificationUrl: URL
194+
public let verificationUrlComplete: URL?
195+
public let expiresIn: Int
196+
public let interval: Double
197+
}

Sources/LispKit/Primitives/HTTPAOuthLibrary.swift

+101-3
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ public final class HTTPOAuthLibrary: NativeLibrary {
4646
private let clientCredentialsReddit: Symbol
4747
private let passwordGrant: Symbol
4848
private let deviceGrant: Symbol
49+
private let userCode: Symbol
50+
private let verificationUrl: Symbol
51+
private let verificationUrlComplete: Symbol
52+
private let expiresIn: Symbol
53+
private let interval: Symbol
54+
private let deviceCode: Symbol
4955

5056
/// Initialize symbols.
5157
public required init(in context: Context) throws {
@@ -62,6 +68,12 @@ public final class HTTPOAuthLibrary: NativeLibrary {
6268
self.clientCredentialsReddit = context.symbols.intern("client-credentials-reddit")
6369
self.passwordGrant = context.symbols.intern("password-grant")
6470
self.deviceGrant = context.symbols.intern("device-grant")
71+
self.userCode = context.symbols.intern("user-code")
72+
self.verificationUrl = context.symbols.intern("verification-url")
73+
self.verificationUrlComplete = context.symbols.intern("verification-url-complete")
74+
self.expiresIn = context.symbols.intern("expires-in")
75+
self.interval = context.symbols.intern("interval")
76+
self.deviceCode = context.symbols.intern("device-code")
6577
try super.init(in: context)
6678
}
6779

@@ -88,6 +100,7 @@ public final class HTTPOAuthLibrary: NativeLibrary {
88100
self.define(Procedure("oauth2-refresh-token", self.oauth2RefreshToken))
89101
self.define(Procedure("oauth2-forget-tokens!", self.oauth2ForgetTokens))
90102
self.define(Procedure("oauth2-cancel-requests!", self.oauth2ForgetTokens))
103+
self.define(Procedure("oauth2-request-codes", self.oauth2RequestCodes))
91104
self.define(Procedure("oauth2-authorize!", self.oauth2Authorize))
92105
self.define(Procedure("http-request-sign!", self.httpRequestSign))
93106
self.define(Procedure("oauth2-session?", self.isOAuth2Session))
@@ -161,7 +174,7 @@ public final class HTTPOAuthLibrary: NativeLibrary {
161174
case self.passwordGrant:
162175
oauth2 = OAuth2PasswordGrant(settings: settings)
163176
case self.deviceGrant:
164-
oauth2 = OAuth2DeviceGrant(settings: settings)
177+
oauth2 = OAuth2DeviceGrantLK(settings: settings)
165178
default:
166179
throw RuntimeError.custom("error", "unknown flow identifier", [.symbol(flow)])
167180
}
@@ -476,6 +489,21 @@ public final class HTTPOAuthLibrary: NativeLibrary {
476489
return settings
477490
}
478491

492+
private func oauth2Params(from: Expr?) throws -> OAuth2StringDict? {
493+
guard var list = from else {
494+
return nil
495+
}
496+
var dict: OAuth2StringDict = [:]
497+
while case .pair(.pair(let key, let value), let rest) = list {
498+
dict[try key.asString()] = try value.asString()
499+
list = rest
500+
}
501+
guard case .null = list else {
502+
throw RuntimeError.type(from!, expected: [.properListType])
503+
}
504+
return dict
505+
}
506+
479507
private func oauth2AccessToken(expr: Expr) throws -> Expr {
480508
if let token = try self.oauth2(from: expr).oauth2.accessToken {
481509
return .makeString(token)
@@ -533,7 +561,49 @@ public final class HTTPOAuthLibrary: NativeLibrary {
533561
return res
534562
}
535563

536-
private func authorizeHandler(_ f: Future) -> (OAuth2JSON?, OAuth2Error?) -> Void {
564+
private func oauth2RequestCodes(expr: Expr, nonTextual: Expr?, params: Expr?) throws -> Expr {
565+
let oauth2 = try self.oauth2(from: expr)
566+
guard let client = oauth2.oauth2 as? OAuth2DeviceGrantLK else {
567+
throw RuntimeError.custom("error", "expecting oauth2 client for the device-grant flow: ", [expr])
568+
}
569+
let params = params == nil ? [:] : try self.oauth2Params(from: params!)
570+
let f = Future(external: false)
571+
HTTPOAuthLibrary.authRequestManager.register(oauth2: client, result: f, in: self.context)
572+
client.start(useNonTextualTransmission: nonTextual?.isTrue ?? false, params: params, queue: nil) { codes, error in
573+
defer {
574+
HTTPOAuthLibrary.authRequestManager.unregister(future: f, in: self.context)
575+
}
576+
do {
577+
if let error {
578+
_ = try f.setResult(in: self.context, to: .error(RuntimeError.os(error)), raise: true)
579+
} else if let codes {
580+
var res = Expr.null
581+
res = .pair(.pair(.symbol(self.interval), .makeNumber(codes.interval)), res)
582+
res = .pair(.pair(.symbol(self.deviceCode), .makeString(codes.deviceCode)), res)
583+
if let url = codes.verificationUrlComplete {
584+
res = .pair(.pair(.symbol(self.verificationUrlComplete), .makeString(url.absoluteString)), res)
585+
}
586+
res = .pair(.pair(.symbol(self.verificationUrl), .makeString(codes.verificationUrl.absoluteString)), res)
587+
res = .pair(.pair(.symbol(self.expiresIn), .makeNumber(codes.expiresIn)), res)
588+
res = .pair(.pair(.symbol(self.userCode), .makeString(codes.userCode)), res)
589+
_ = try f.setResult(in: self.context, to: res, raise: false)
590+
} else {
591+
_ = try f.setResult(in: self.context,
592+
to: .error(RuntimeError.eval(.serverError)),
593+
raise: true)
594+
}
595+
} catch {
596+
do {
597+
_ = try f.setResult(in: self.context,
598+
to: .error(RuntimeError.eval(.serverError, .object(f))),
599+
raise: true)
600+
} catch {}
601+
}
602+
}
603+
return .object(f)
604+
}
605+
606+
private func authorizeHandler(_ f: Future) -> (OAuth2JSON?, Error?) -> Void {
537607
return { params, error in
538608
defer {
539609
HTTPOAuthLibrary.authRequestManager.unregister(future: f, in: self.context)
@@ -562,7 +632,35 @@ public final class HTTPOAuthLibrary: NativeLibrary {
562632
let oauth2 = try self.oauth2(from: expr)
563633
let f = Future(external: false)
564634
HTTPOAuthLibrary.authRequestManager.register(oauth2: oauth2.oauth2, result: f, in: self.context)
565-
oauth2.oauth2.authorize(callback: self.authorizeHandler(f))
635+
if let oauth2DeviceGrant = oauth2.oauth2 as? OAuth2DeviceGrantLK {
636+
if oauth2DeviceGrant.hasUnexpiredAccessToken() {
637+
var params = Expr.null
638+
params = .pair(.pair(.makeString("token_type"), .makeString("bearer")), params)
639+
if let scope = oauth2DeviceGrant.scope {
640+
params = .pair(.pair(.makeString("scope"), .makeString(scope)), params)
641+
}
642+
if let accessToken = oauth2DeviceGrant.accessToken {
643+
params = .pair(.pair(.makeString("access_token"), .makeString(accessToken)), params)
644+
}
645+
_ = try f.setResult(in: self.context, to: params, raise: false)
646+
} else if let deviceCode = oauth2DeviceGrant.deviceCode {
647+
let callback = self.authorizeHandler(f)
648+
oauth2DeviceGrant.getDeviceAccessToken(deviceCode: deviceCode,
649+
interval: oauth2DeviceGrant.pollingInterval,
650+
queue: .global(qos: .default)) { params, error in
651+
if let params {
652+
oauth2DeviceGrant.didAuthorize(withParameters: params)
653+
} else if let error {
654+
oauth2DeviceGrant.didFail(with: error.asOAuth2Error)
655+
}
656+
callback(params, error)
657+
}
658+
} else {
659+
throw RuntimeError.custom("error", "OAuth2 device grant client did not yet receive device code: ", [expr])
660+
}
661+
} else {
662+
oauth2.oauth2.authorize(callback: self.authorizeHandler(f))
663+
}
566664
return .object(f)
567665
}
568666

0 commit comments

Comments
 (0)