diff --git a/Sources/NetShears/URLProtocol/NetwrokListenerUrlProtocol.swift b/Sources/NetShears/URLProtocol/NetwrokListenerUrlProtocol.swift index edb4441..2f5e88c 100644 --- a/Sources/NetShears/URLProtocol/NetwrokListenerUrlProtocol.swift +++ b/Sources/NetShears/URLProtocol/NetwrokListenerUrlProtocol.swift @@ -11,27 +11,28 @@ class NetwrokListenerUrlProtocol: URLProtocol { struct Constants { static let RequestHandledKey = "NetworkListenerUrlProtocol" + static let RequestID = "request-id" } var session: URLSession? - var sessionTask: URLSessionDataTask? - var currentRequest: NetShearsRequestModel? + var sessionTasks: ThreadSafeDictionary? = .init() + var currentRequests: ThreadSafeDictionary? = .init() + lazy var requestObserver: RequestObserverProtocol = { RequestObserver(options: [ RequestBroadcast.shared ]) }() - + override init(request: URLRequest, cachedResponse: CachedURLResponse?, client: URLProtocolClient?) { super.init(request: request, cachedResponse: cachedResponse, client: client) - + if session == nil { session = URLSession(configuration: .default, delegate: self, delegateQueue: nil) } } - + override class func canInit(with request: URLRequest) -> Bool { - if NetwrokListenerUrlProtocol.property(forKey: Constants.RequestHandledKey, in: request) != nil { return false } @@ -39,33 +40,41 @@ class NetwrokListenerUrlProtocol: URLProtocol { } override class func canonicalRequest(for request: URLRequest) -> URLRequest { - return request + var newRequest = request + let id = UUID().uuidString + newRequest.addValue(id, forHTTPHeaderField: Constants.RequestID) + return newRequest } override func startLoading() { let newRequest = ((request as NSURLRequest).mutableCopy() as? NSMutableURLRequest)! NetwrokListenerUrlProtocol.setProperty(true, forKey: Constants.RequestHandledKey, in: newRequest) - sessionTask = session?.dataTask(with: newRequest as URLRequest) - sessionTask?.resume() - - currentRequest = NetShearsRequestModel(request: newRequest, session: session) - if let request = currentRequest { + guard let id = newRequest.value(forHTTPHeaderField: Constants.RequestID) else { return } + sessionTasks?[id] = session?.dataTask(with: newRequest as URLRequest) + sessionTasks?[id]?.resume() + + currentRequests?[id] = NetShearsRequestModel(request: newRequest, session: session) + if let request = currentRequests?[id] { requestObserver.newRequestArrived(request) } } override func stopLoading() { - sessionTask?.cancel() - currentRequest?.httpBody = body(from: request) - if let startDate = currentRequest?.date{ - currentRequest?.duration = fabs(startDate.timeIntervalSinceNow) * 1000 //Find elapsed time and convert to milliseconds + guard let id = request.value(forHTTPHeaderField: Constants.RequestID) else { return } + sessionTasks?[id]?.cancel() + currentRequests?[id]?.httpBody = body(from: request) + + if let startDate = currentRequests?[id]?.date { + currentRequests?[id]?.duration = fabs(startDate.timeIntervalSinceNow) * 1000 //Find elapsed time and convert to milliseconds } - currentRequest?.isFinished = true - - if let request = currentRequest { + currentRequests?[id]?.isFinished = true + + if let request = currentRequests?[id] { requestObserver.newRequestArrived(request) } session?.invalidateAndCancel() + sessionTasks?[id] = nil + currentRequests?[id] = nil } private func body(from request: URLRequest) -> Data? { @@ -77,23 +86,28 @@ class NetwrokListenerUrlProtocol: URLProtocol { deinit { session = nil - sessionTask = nil - currentRequest = nil + sessionTasks = .init() + currentRequests = .init() } } extension NetwrokListenerUrlProtocol: URLSessionDataDelegate { func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) { client?.urlProtocol(self, didLoad: data) - if currentRequest?.dataResponse == nil{ + guard let id = request.value(forHTTPHeaderField: Constants.RequestID) else { return } + let currentRequest = currentRequests?[id] + + if currentRequest?.dataResponse == nil { currentRequest?.dataResponse = data - } - else{ + } else { currentRequest?.dataResponse?.append(data) } } func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive response: URLResponse, completionHandler: @escaping (URLSession.ResponseDisposition) -> Void) { + guard let id = request.value(forHTTPHeaderField: Constants.RequestID) else { return } + let currentRequest = currentRequests?[id] + let policy = URLCache.StoragePolicy(rawValue: request.cachePolicy.rawValue) ?? .notAllowed client?.urlProtocol(self, didReceive: response, cacheStoragePolicy: policy) currentRequest?.initResponse(response: response) @@ -102,6 +116,8 @@ extension NetwrokListenerUrlProtocol: URLSessionDataDelegate { func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) { if let error = error { + guard let id = request.value(forHTTPHeaderField: Constants.RequestID) else { return } + let currentRequest = currentRequests?[id] currentRequest?.errorClientDescription = error.localizedDescription client?.urlProtocol(self, didFailWithError: error) } else { @@ -116,6 +132,8 @@ extension NetwrokListenerUrlProtocol: URLSessionDataDelegate { func urlSession(_ session: URLSession, didBecomeInvalidWithError error: Error?) { guard let error = error else { return } + guard let id = request.value(forHTTPHeaderField: Constants.RequestID) else { return } + let currentRequest = currentRequests?[id] currentRequest?.errorClientDescription = error.localizedDescription client?.urlProtocol(self, didFailWithError: error) } diff --git a/Sources/NetShears/Utils/ThreadSafe.swift b/Sources/NetShears/Utils/ThreadSafe.swift index 12c0853..2714374 100644 --- a/Sources/NetShears/Utils/ThreadSafe.swift +++ b/Sources/NetShears/Utils/ThreadSafe.swift @@ -24,3 +24,21 @@ final class ThreadSafe { } } } + +final class ThreadSafeDictionary { + private var dictionary: [K: V] = [:] + private var queue = DispatchQueue(label: UUID().uuidString, attributes: .concurrent) + + final subscript(_ key: K) -> V? { + get { + queue.sync { + dictionary[key] + } + } + set(value) { + queue.sync(flags: .barrier) { + dictionary[key] = value + } + } + } +}