From caa5caf4ca64e79c3ad8f64e2a49f9b85ef1bc19 Mon Sep 17 00:00:00 2001 From: David Koski <46639364+davidkoski@users.noreply.github.com> Date: Mon, 30 Sep 2024 14:21:58 -0700 Subject: [PATCH] fix swift 6 warnings - thread safe tokenizer and model config (#126) - replaces #125 with simpler mechanism (NSLock) Co-authored-by: John Mai --- Libraries/LLM/Configuration.swift | 49 ++++++++++++++----- Libraries/LLM/Tokenizer.swift | 35 ++++++++++--- .../LinearModelTraining.swift | 2 +- Tools/image-tool/Arguments.swift | 2 +- Tools/image-tool/ImageTool.swift | 2 +- Tools/llm-tool/Arguments.swift | 2 +- Tools/mnist-tool/MNISTTool.swift | 2 +- 7 files changed, 72 insertions(+), 22 deletions(-) diff --git a/Libraries/LLM/Configuration.swift b/Libraries/LLM/Configuration.swift index 1d7f585..f432af4 100644 --- a/Libraries/LLM/Configuration.swift +++ b/Libraries/LLM/Configuration.swift @@ -26,20 +26,21 @@ public enum StringOrNumber: Codable, Equatable, Sendable { } } -public struct ModelType: RawRepresentable, Codable, Sendable { - public let rawValue: String +private class ModelTypeRegistry: @unchecked Sendable { - public init(rawValue: String) { - self.rawValue = rawValue - } + // Note: using NSLock as we have very small (just dictionary get/set) + // critical sections and expect no contention. this allows the methods + // to remain synchronous. + private let lock = NSLock() + @Sendable private static func createLlamaModel(url: URL) throws -> LLMModel { let configuration = try JSONDecoder().decode( LlamaConfiguration.self, from: Data(contentsOf: url)) return LlamaModel(configuration) } - private static var creators: [String: (URL) throws -> LLMModel] = [ + private var creators: [String: @Sendable (URL) throws -> LLMModel] = [ "mistral": createLlamaModel, "llama": createLlamaModel, "phi": { url in @@ -89,18 +90,44 @@ public struct ModelType: RawRepresentable, Codable, Sendable { }, ] - public static func registerModelType( - _ type: String, creator: @escaping (URL) throws -> LLMModel + public func registerModelType( + _ type: String, creator: @Sendable @escaping (URL) throws -> LLMModel ) { - creators[type] = creator + lock.withLock { + creators[type] = creator + } } - public func createModel(configuration: URL) throws -> LLMModel { - guard let creator = ModelType.creators[rawValue] else { + public func createModel(configuration: URL, rawValue: String) throws -> LLMModel { + let creator = lock.withLock { + creators[rawValue] + } + guard let creator else { throw LLMError(message: "Unsupported model type.") } return try creator(configuration) } + +} + +private let modelTypeRegistry = ModelTypeRegistry() + +public struct ModelType: RawRepresentable, Codable, Sendable { + public let rawValue: String + + public init(rawValue: String) { + self.rawValue = rawValue + } + + public static func registerModelType( + _ type: String, creator: @Sendable @escaping (URL) throws -> LLMModel + ) { + modelTypeRegistry.registerModelType(type, creator: creator) + } + + public func createModel(configuration: URL) throws -> LLMModel { + try modelTypeRegistry.createModel(configuration: configuration, rawValue: rawValue) + } } public struct BaseConfiguration: Codable, Sendable { diff --git a/Libraries/LLM/Tokenizer.swift b/Libraries/LLM/Tokenizer.swift index a21da89..e79782a 100644 --- a/Libraries/LLM/Tokenizer.swift +++ b/Libraries/LLM/Tokenizer.swift @@ -67,12 +67,35 @@ private func updateTokenizerConfig(_ tokenizerConfig: Config) -> Config { return tokenizerConfig } -/// overrides for TokenizerModel/knownTokenizers -public var replacementTokenizers = [ - "InternLM2Tokenizer": "PreTrainedTokenizer", - "Qwen2Tokenizer": "PreTrainedTokenizer", - "CohereTokenizer": "PreTrainedTokenizer", -] +public class TokenizerReplacementRegistry: @unchecked Sendable { + + // Note: using NSLock as we have very small (just dictionary get/set) + // critical sections and expect no contention. this allows the methods + // to remain synchronous. + private let lock = NSLock() + + /// overrides for TokenizerModel/knownTokenizers + private var replacementTokenizers = [ + "InternLM2Tokenizer": "PreTrainedTokenizer", + "Qwen2Tokenizer": "PreTrainedTokenizer", + "CohereTokenizer": "PreTrainedTokenizer", + ] + + public subscript(key: String) -> String? { + get { + lock.withLock { + replacementTokenizers[key] + } + } + set { + lock.withLock { + replacementTokenizers[key] = newValue + } + } + } +} + +public let replacementTokenizers = TokenizerReplacementRegistry() public protocol StreamingDetokenizer: IteratorProtocol { diff --git a/Tools/LinearModelTraining/LinearModelTraining.swift b/Tools/LinearModelTraining/LinearModelTraining.swift index fb10704..17439fb 100644 --- a/Tools/LinearModelTraining/LinearModelTraining.swift +++ b/Tools/LinearModelTraining/LinearModelTraining.swift @@ -7,7 +7,7 @@ import MLXNN import MLXOptimizers import MLXRandom -#if swift(>=6.0) +#if swift(>=5.10) extension MLX.DeviceType: @retroactive ExpressibleByArgument { public init?(argument: String) { self.init(rawValue: argument) diff --git a/Tools/image-tool/Arguments.swift b/Tools/image-tool/Arguments.swift index 0d67700..36380fc 100644 --- a/Tools/image-tool/Arguments.swift +++ b/Tools/image-tool/Arguments.swift @@ -4,7 +4,7 @@ import ArgumentParser import Foundation import MLX -#if swift(>=6.0) +#if swift(>=5.10) /// Extension to allow URL command line arguments. extension URL: @retroactive ExpressibleByArgument { public init?(argument: String) { diff --git a/Tools/image-tool/ImageTool.swift b/Tools/image-tool/ImageTool.swift index 9b02786..bd6aea4 100644 --- a/Tools/image-tool/ImageTool.swift +++ b/Tools/image-tool/ImageTool.swift @@ -22,7 +22,7 @@ struct StableDiffusionTool: AsyncParsableCommand { ) } -#if swift(>=6.0) +#if swift(>=5.10) extension StableDiffusionConfiguration.Preset: @retroactive ExpressibleByArgument {} #else extension StableDiffusionConfiguration.Preset: ExpressibleByArgument {} diff --git a/Tools/llm-tool/Arguments.swift b/Tools/llm-tool/Arguments.swift index 281dcd6..a2d1d49 100644 --- a/Tools/llm-tool/Arguments.swift +++ b/Tools/llm-tool/Arguments.swift @@ -4,7 +4,7 @@ import ArgumentParser import Foundation /// Extension to allow URL command line arguments. -#if swift(>=6.0) +#if swift(>=5.10) extension URL: @retroactive ExpressibleByArgument { public init?(argument: String) { if argument.contains("://") { diff --git a/Tools/mnist-tool/MNISTTool.swift b/Tools/mnist-tool/MNISTTool.swift index c401031..7557ec7 100644 --- a/Tools/mnist-tool/MNISTTool.swift +++ b/Tools/mnist-tool/MNISTTool.swift @@ -16,7 +16,7 @@ struct MNISTTool: AsyncParsableCommand { defaultSubcommand: Train.self) } -#if swift(>=6.0) +#if swift(>=5.10) extension MLX.DeviceType: @retroactive ExpressibleByArgument { public init?(argument: String) { self.init(rawValue: argument)