Skip to content

Commit

Permalink
fix swift 6 warnings - thread safe tokenizer and model config (#126)
Browse files Browse the repository at this point in the history
- replaces #125 with simpler mechanism (NSLock)

Co-authored-by: John Mai <[email protected]>
  • Loading branch information
davidkoski and johnmai-dev authored Sep 30, 2024
1 parent ee94992 commit caa5caf
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 22 deletions.
49 changes: 38 additions & 11 deletions Libraries/LLM/Configuration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,21 @@ public enum StringOrNumber: Codable, Equatable, Sendable {
}
}

public struct ModelType: RawRepresentable, Codable, Sendable {
public let rawValue: String
private class ModelTypeRegistry: @unchecked Sendable {

public init(rawValue: String) {
self.rawValue = rawValue
}
// Note: using NSLock as we have very small (just dictionary get/set)
// critical sections and expect no contention. this allows the methods
// to remain synchronous.
private let lock = NSLock()

@Sendable
private static func createLlamaModel(url: URL) throws -> LLMModel {
let configuration = try JSONDecoder().decode(
LlamaConfiguration.self, from: Data(contentsOf: url))
return LlamaModel(configuration)
}

private static var creators: [String: (URL) throws -> LLMModel] = [
private var creators: [String: @Sendable (URL) throws -> LLMModel] = [
"mistral": createLlamaModel,
"llama": createLlamaModel,
"phi": { url in
Expand Down Expand Up @@ -89,18 +90,44 @@ public struct ModelType: RawRepresentable, Codable, Sendable {
},
]

public static func registerModelType(
_ type: String, creator: @escaping (URL) throws -> LLMModel
public func registerModelType(
_ type: String, creator: @Sendable @escaping (URL) throws -> LLMModel
) {
creators[type] = creator
lock.withLock {
creators[type] = creator
}
}

public func createModel(configuration: URL) throws -> LLMModel {
guard let creator = ModelType.creators[rawValue] else {
public func createModel(configuration: URL, rawValue: String) throws -> LLMModel {
let creator = lock.withLock {
creators[rawValue]
}
guard let creator else {
throw LLMError(message: "Unsupported model type.")
}
return try creator(configuration)
}

}

private let modelTypeRegistry = ModelTypeRegistry()

public struct ModelType: RawRepresentable, Codable, Sendable {
public let rawValue: String

public init(rawValue: String) {
self.rawValue = rawValue
}

public static func registerModelType(
_ type: String, creator: @Sendable @escaping (URL) throws -> LLMModel
) {
modelTypeRegistry.registerModelType(type, creator: creator)
}

public func createModel(configuration: URL) throws -> LLMModel {
try modelTypeRegistry.createModel(configuration: configuration, rawValue: rawValue)
}
}

public struct BaseConfiguration: Codable, Sendable {
Expand Down
35 changes: 29 additions & 6 deletions Libraries/LLM/Tokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,35 @@ private func updateTokenizerConfig(_ tokenizerConfig: Config) -> Config {
return tokenizerConfig
}

/// overrides for TokenizerModel/knownTokenizers
public var replacementTokenizers = [
"InternLM2Tokenizer": "PreTrainedTokenizer",
"Qwen2Tokenizer": "PreTrainedTokenizer",
"CohereTokenizer": "PreTrainedTokenizer",
]
public class TokenizerReplacementRegistry: @unchecked Sendable {

// Note: using NSLock as we have very small (just dictionary get/set)
// critical sections and expect no contention. this allows the methods
// to remain synchronous.
private let lock = NSLock()

/// overrides for TokenizerModel/knownTokenizers
private var replacementTokenizers = [
"InternLM2Tokenizer": "PreTrainedTokenizer",
"Qwen2Tokenizer": "PreTrainedTokenizer",
"CohereTokenizer": "PreTrainedTokenizer",
]

public subscript(key: String) -> String? {
get {
lock.withLock {
replacementTokenizers[key]
}
}
set {
lock.withLock {
replacementTokenizers[key] = newValue
}
}
}
}

public let replacementTokenizers = TokenizerReplacementRegistry()

public protocol StreamingDetokenizer: IteratorProtocol<String> {

Expand Down
2 changes: 1 addition & 1 deletion Tools/LinearModelTraining/LinearModelTraining.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import MLXNN
import MLXOptimizers
import MLXRandom

#if swift(>=6.0)
#if swift(>=5.10)
extension MLX.DeviceType: @retroactive ExpressibleByArgument {
public init?(argument: String) {
self.init(rawValue: argument)
Expand Down
2 changes: 1 addition & 1 deletion Tools/image-tool/Arguments.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import ArgumentParser
import Foundation
import MLX

#if swift(>=6.0)
#if swift(>=5.10)
/// Extension to allow URL command line arguments.
extension URL: @retroactive ExpressibleByArgument {
public init?(argument: String) {
Expand Down
2 changes: 1 addition & 1 deletion Tools/image-tool/ImageTool.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct StableDiffusionTool: AsyncParsableCommand {
)
}

#if swift(>=6.0)
#if swift(>=5.10)
extension StableDiffusionConfiguration.Preset: @retroactive ExpressibleByArgument {}
#else
extension StableDiffusionConfiguration.Preset: ExpressibleByArgument {}
Expand Down
2 changes: 1 addition & 1 deletion Tools/llm-tool/Arguments.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import ArgumentParser
import Foundation

/// Extension to allow URL command line arguments.
#if swift(>=6.0)
#if swift(>=5.10)
extension URL: @retroactive ExpressibleByArgument {
public init?(argument: String) {
if argument.contains("://") {
Expand Down
2 changes: 1 addition & 1 deletion Tools/mnist-tool/MNISTTool.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct MNISTTool: AsyncParsableCommand {
defaultSubcommand: Train.self)
}

#if swift(>=6.0)
#if swift(>=5.10)
extension MLX.DeviceType: @retroactive ExpressibleByArgument {
public init?(argument: String) {
self.init(rawValue: argument)
Expand Down

0 comments on commit caa5caf

Please sign in to comment.