diff --git a/Sources/Generation/MLMultiArray+Utils.swift b/Sources/Generation/CoreML+Extensions.swift similarity index 82% rename from Sources/Generation/MLMultiArray+Utils.swift rename to Sources/Generation/CoreML+Extensions.swift index f3592233..35f627b3 100644 --- a/Sources/Generation/MLMultiArray+Utils.swift +++ b/Sources/Generation/CoreML+Extensions.swift @@ -1,5 +1,5 @@ // -// MLMultiArray+Utils.swift +// CoreML+Extensions.swift // CoreMLBert // // Created by Julien Chaumond on 27/06/2019. @@ -10,7 +10,7 @@ import CoreML import Foundation -public extension MLMultiArray { +extension MLMultiArray { /// All values will be stored in the last dimension of the MLMultiArray (default is dims=1) static func from(_ arr: [Int], dims: Int = 1) -> MLMultiArray { var shape = Array(repeating: 1, count: dims) @@ -88,7 +88,7 @@ public extension MLMultiArray { } } -public extension MLMultiArray { +extension MLMultiArray { /// Provides a way to index n-dimensionals arrays a la numpy. enum Indexing: Equatable { case select(Int) @@ -197,4 +197,48 @@ extension MLMultiArray { return s + "]" } } + +extension MLShapedArray { + var floats: [Float] { + guard strides.first == 1, strides.count == 1 else { + // For some reason this path is slow. + // If strides is not 1, we can write a Metal kernel to copy the values properly. + return scalars + } + + // Fast path: memcpy + let mlArray = MLMultiArray(self) + return mlArray.floats ?? scalars + } +} + +extension MLShapedArraySlice { + var floats: [Float] { + guard strides.first == 1, strides.count == 1 else { + // For some reason this path is slow. + // If strides is not 1, we can write a Metal kernel to copy the values properly. + return scalars + } + + // Fast path: memcpy + let mlArray = MLMultiArray(self) + return mlArray.floats ?? scalars + } +} + +extension MLMultiArray { + var floats: [Float]? { + guard dataType == .float32 else { return nil } + + var result: [Float] = Array(repeating: 0, count: count) + return withUnsafeBytes { ptr in + guard let source = ptr.baseAddress else { return nil } + result.withUnsafeMutableBytes { resultPtr in + let dest = resultPtr.baseAddress! + memcpy(dest, source, self.count * MemoryLayout.stride) + } + return result + } + } +} #endif // canImport(CoreML) diff --git a/Sources/Generation/MLShapedArray+Utils.swift b/Sources/Generation/MLShapedArray+Utils.swift deleted file mode 100644 index fd25944d..00000000 --- a/Sources/Generation/MLShapedArray+Utils.swift +++ /dev/null @@ -1,54 +0,0 @@ -// -// MLShapedArray+Utils.swift -// -// -// Created by Pedro Cuenca on 13/5/23. -// - -#if canImport(CoreML) -import CoreML - -public extension MLShapedArray { - var floats: [Float] { - guard strides.first == 1, strides.count == 1 else { - // For some reason this path is slow. - // If strides is not 1, we can write a Metal kernel to copy the values properly. - return scalars - } - - // Fast path: memcpy - let mlArray = MLMultiArray(self) - return mlArray.floats ?? scalars - } -} - -public extension MLShapedArraySlice { - var floats: [Float] { - guard strides.first == 1, strides.count == 1 else { - // For some reason this path is slow. - // If strides is not 1, we can write a Metal kernel to copy the values properly. - return scalars - } - - // Fast path: memcpy - let mlArray = MLMultiArray(self) - return mlArray.floats ?? scalars - } -} - -public extension MLMultiArray { - var floats: [Float]? { - guard dataType == .float32 else { return nil } - - var result: [Float] = Array(repeating: 0, count: count) - return withUnsafeBytes { ptr in - guard let source = ptr.baseAddress else { return nil } - result.withUnsafeMutableBytes { resultPtr in - let dest = resultPtr.baseAddress! - memcpy(dest, source, self.count * MemoryLayout.stride) - } - return result - } - } -} -#endif // canImport(CoreML) diff --git a/Sources/Hub/BOMDoubling.swift b/Sources/Hub/Extensions/JSONSerialization+BOM.swift similarity index 96% rename from Sources/Hub/BOMDoubling.swift rename to Sources/Hub/Extensions/JSONSerialization+BOM.swift index 9ed906f0..84651259 100644 --- a/Sources/Hub/BOMDoubling.swift +++ b/Sources/Hub/Extensions/JSONSerialization+BOM.swift @@ -1,5 +1,5 @@ // -// BOMDoubling.swift +// JSONSerialization+BOM.swift // swift-transformers // // Created by Pedro Cuenca on 20250912 @@ -7,7 +7,13 @@ import Foundation -extension Data { +extension JSONSerialization { + class func bomPreservingJsonObject(with data: Data, options: JSONSerialization.ReadingOptions = []) throws -> Any { + try JSONSerialization.jsonObject(with: data.duplicatingBOMsAfterQuotes, options: options) + } +} + +private extension Data { /// Workaround for https://github.com/huggingface/swift-transformers/issues/116 /// Duplicate a BOM sequence that follows a quote. The first BOM is swallowed by JSONSerialization.jsonObject /// because it thinks it marks the encoding. @@ -40,9 +46,3 @@ extension Data { } } } - -extension JSONSerialization { - class func bomPreservingJsonObject(with data: Data, options: JSONSerialization.ReadingOptions = []) throws -> Any { - try JSONSerialization.jsonObject(with: data.duplicatingBOMsAfterQuotes, options: options) - } -} diff --git a/Sources/Hub/HubApi.swift b/Sources/Hub/HubApi.swift index 7def999e..59ec26ae 100644 --- a/Sources/Hub/HubApi.swift +++ b/Sources/Hub/HubApi.swift @@ -891,13 +891,13 @@ public extension Hub { } } -public extension [String] { +private extension [String] { func matching(glob: String) -> [String] { filter { fnmatch(glob, $0, 0) == 0 } } } -public extension FileManager { +private extension FileManager { func getFileUrls(at directoryUrl: URL) throws -> [URL] { var fileUrls = [URL]() diff --git a/Sources/Tokenizers/BertTokenizer.swift b/Sources/Tokenizers/BertTokenizer.swift index 30d02563..92bfa373 100644 --- a/Sources/Tokenizers/BertTokenizer.swift +++ b/Sources/Tokenizers/BertTokenizer.swift @@ -226,7 +226,7 @@ class BasicTokenizer { } } -extension Character { +private extension Character { /// https://github.com/huggingface/transformers/blob/8c1b5d37827a6691fef4b2d926f2d04fb6f5a9e3/src/transformers/tokenization_utils.py#L367 var isExtendedPunctuation: Bool { if isPunctuation { return true } diff --git a/Sources/Tokenizers/Decoder.swift b/Sources/Tokenizers/Decoder.swift index 217297d6..a7b0b15e 100644 --- a/Sources/Tokenizers/Decoder.swift +++ b/Sources/Tokenizers/Decoder.swift @@ -236,7 +236,7 @@ class MetaspaceDecoder: Decoder { } /// We could use firstIndex(where:), lastIndex(where:) for possibly better efficiency (and do both ends at once) -public extension String { +private extension String { func trimmingFromStart(character: Character = " ", upto: Int) -> String { var result = self var trimmed = 0 diff --git a/Sources/Tokenizers/PreTokenizer.swift b/Sources/Tokenizers/PreTokenizer.swift index 29918b21..92df575d 100644 --- a/Sources/Tokenizers/PreTokenizer.swift +++ b/Sources/Tokenizers/PreTokenizer.swift @@ -238,164 +238,3 @@ class SplitPreTokenizer: PreTokenizer { return pattern.split(text, invert: invert) } } - -enum StringSplitPattern { - case regexp(regexp: String) - case string(pattern: String) -} - -extension StringSplitPattern { - func split(_ text: String, invert: Bool = true) -> [String] { - switch self { - case let .regexp(regexp): - text.split(by: regexp, includeSeparators: true) - case let .string(substring): - text.split(by: substring, options: [], includeSeparators: !invert) - } - } -} - -extension StringSplitPattern { - static func from(config: Config) -> StringSplitPattern? { - if let pattern = config.pattern.String.string() { - return StringSplitPattern.string(pattern: pattern) - } - if let pattern = config.pattern.Regex.string() { - return StringSplitPattern.regexp(regexp: pattern) - } - return nil - } -} - -public extension String { - func ranges(of string: String, options: CompareOptions = .regularExpression) -> [Range] { - var result: [Range] = [] - var start = startIndex - while let range = range(of: string, options: options, range: start.. [String] { - var result: [String] = [] - var start = startIndex - while let range = range(of: string, options: options, range: start.. [String] { - // Find the matching capture groups - let selfRange = NSRange(startIndex.. - // https://stackoverflow.com/questions/75543272/convert-a-given-utf8-nsrange-in-a-string-to-a-utf16-nsrange - guard let matchRange = Range(match.range, in: self) else { continue } - - // Add text before the match - if start < matchRange.lowerBound { - result.append(String(self[start.. [String] { - func mergedWithNext(ranges: [Range]) -> [Range] { - var merged: [Range] = [] - var currentStart = startIndex - for range in ranges { - if range.lowerBound == startIndex { continue } - let mergedRange = currentStart..]) -> [Range] { - var merged: [Range] = [] - var currentStart = startIndex - for range in ranges { - let mergedRange = currentStart.. (3, 4), (9, 10), (10, 11) -> (start, 2), (3, 8), (9, 9), (10, end) - let ranges = ranges(of: string, options: options) - let merged = mergedWithNext(ranges: ranges) - return merged.map { String(self[$0]) } - case .mergedWithPrevious: - // Obtain ranges and merge them - // "the-final--countdown" -> (3, 4), (9, 10), (10, 11) -> (start, 3), (4, 9), (10, 10), (11, end) - let ranges = ranges(of: string, options: options) - let merged = mergedWithPrevious(ranges: ranges) - return merged.map { String(self[$0]) } - } - } -} diff --git a/Sources/Tokenizers/String+PreTokenization.swift b/Sources/Tokenizers/String+PreTokenization.swift new file mode 100644 index 00000000..54f875d0 --- /dev/null +++ b/Sources/Tokenizers/String+PreTokenization.swift @@ -0,0 +1,157 @@ +import Foundation +import struct Hub.Config + +enum StringSplitPattern { + case regexp(regexp: String) + case string(pattern: String) + + func split(_ text: String, invert: Bool = true) -> [String] { + switch self { + case let .regexp(regexp): + text.split(by: regexp, includeSeparators: true) + case let .string(substring): + text.split(by: substring, options: [], includeSeparators: !invert) + } + } + + static func from(config: Config) -> StringSplitPattern? { + if let pattern = config.pattern.String.string() { + return .string(pattern: pattern) + } + if let pattern = config.pattern.Regex.string() { + return .regexp(regexp: pattern) + } + return nil + } +} + +enum SplitDelimiterBehavior { + case removed + case isolated + case mergedWithPrevious + case mergedWithNext +} + +extension String { + func ranges(of string: String, options: CompareOptions = .regularExpression) -> [Range] { + var result: [Range] = [] + var start = startIndex + while let range = range(of: string, options: options, range: start.. [String] { + var result: [String] = [] + var start = startIndex + while let range = range(of: string, options: options, range: start.. [String] { + // Find the matching capture groups + let selfRange = NSRange(startIndex.. + // https://stackoverflow.com/questions/75543272/convert-a-given-utf8-nsrange-in-a-string-to-a-utf16-nsrange + guard let matchRange = Range(match.range, in: self) else { continue } + + // Add text before the match + if start < matchRange.lowerBound { + result.append(String(self[start.. [String] { + func mergedWithNext(ranges: [Range]) -> [Range] { + var merged: [Range] = [] + var currentStart = startIndex + for range in ranges { + if range.lowerBound == startIndex { continue } + let mergedRange = currentStart..]) -> [Range] { + var merged: [Range] = [] + var currentStart = startIndex + for range in ranges { + let mergedRange = currentStart.. (3, 4), (9, 10), (10, 11) -> (start, 2), (3, 8), (9, 9), (10, end) + let ranges = ranges(of: string, options: options) + let merged = mergedWithNext(ranges: ranges) + return merged.map { String(self[$0]) } + case .mergedWithPrevious: + // Obtain ranges and merge them + // "the-final--countdown" -> (3, 4), (9, 10), (10, 11) -> (start, 3), (4, 9), (10, 10), (11, end) + let ranges = ranges(of: string, options: options) + let merged = mergedWithPrevious(ranges: ranges) + return merged.map { String(self[$0]) } + } + } +} diff --git a/Tests/TokenizersTests/ChatTemplateTests.swift b/Tests/TokenizersTests/ChatTemplateTests.swift index 15d74c4f..2e0e57f3 100644 --- a/Tests/TokenizersTests/ChatTemplateTests.swift +++ b/Tests/TokenizersTests/ChatTemplateTests.swift @@ -7,7 +7,7 @@ import Foundation import Testing -import Tokenizers +@testable import Tokenizers @Suite("Chat Template Tests") struct ChatTemplateTests { diff --git a/Tests/TokenizersTests/FactoryTests.swift b/Tests/TokenizersTests/FactoryTests.swift index b5bf1ebc..ad8cecbb 100644 --- a/Tests/TokenizersTests/FactoryTests.swift +++ b/Tests/TokenizersTests/FactoryTests.swift @@ -8,7 +8,7 @@ import Foundation import Hub import Testing -import Tokenizers +@testable import Tokenizers private func makeHubApi() -> (api: HubApi, downloadDestination: URL) { let base = FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first! diff --git a/Tests/TokenizersTests/PreTokenizerTests.swift b/Tests/TokenizersTests/PreTokenizerTests.swift index b0cd4ba6..13292a2a 100644 --- a/Tests/TokenizersTests/PreTokenizerTests.swift +++ b/Tests/TokenizersTests/PreTokenizerTests.swift @@ -176,6 +176,70 @@ struct PreTokenizerTests { ) } + @Test("Split behavior merged with previous") + func splitBehaviorMergedWithPrevious() { + #expect( + "the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .mergedWithPrevious) == + ["the-", "final-", "-", "countdown"] + ) + + #expect( + "the-final--countdown-".split(by: "-", options: .caseInsensitive, behavior: .mergedWithPrevious) == + ["the-", "final-", "-", "countdown-"] + ) + + #expect( + "the-final--countdown--".split(by: "-", options: .caseInsensitive, behavior: .mergedWithPrevious) == + ["the-", "final-", "-", "countdown-", "-"] + ) + + #expect( + "-the-final--countdown--".split(by: "-", options: .caseInsensitive, behavior: .mergedWithPrevious) == + ["-", "the-", "final-", "-", "countdown-", "-"] + ) + + #expect( + "--the-final--countdown--".split(by: "-", options: .caseInsensitive, behavior: .mergedWithPrevious) == + ["-", "-", "the-", "final-", "-", "countdown-", "-"] + ) + } + + @Test("Split behavior merged with next") + func splitBehaviorMergedWithNext() { + #expect( + "the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .mergedWithNext) == + ["the", "-final", "-", "-countdown"] + ) + + #expect( + "-the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .mergedWithNext) == + ["-the", "-final", "-", "-countdown"] + ) + + #expect( + "--the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .mergedWithNext) == + ["-", "-the", "-final", "-", "-countdown"] + ) + + #expect( + "--the-final--countdown-".split(by: "-", options: .caseInsensitive, behavior: .mergedWithNext) == + ["-", "-the", "-final", "-", "-countdown", "-"] + ) + } + + @Test("Split behavior other") + func splitBehaviorOther() { + #expect( + "the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .isolated) == + ["the", "-", "final", "-", "-", "countdown"] + ) + + #expect( + "the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .removed) == + ["the", "final", "countdown"] + ) + } + /// https://github.com/huggingface/tokenizers/pull/1357 @Test("Metaspace pre-tokenizer with prefix space handling") func metaspacePreTokenizer() { diff --git a/Tests/TokenizersTests/SplitTests.swift b/Tests/TokenizersTests/SplitTests.swift deleted file mode 100644 index 5695b683..00000000 --- a/Tests/TokenizersTests/SplitTests.swift +++ /dev/null @@ -1,104 +0,0 @@ -// -// SplitTests.swift -// -// -// Created by Pedro Cuenca on 20240120. -// - -import Foundation -import Testing -import Tokenizers - -@Suite("Split Behavior Tests") -struct SplitTests { - @Test("String splitting with capture groups") - func splitWithCaptureGroups() { - let addedTokensRegexp = #"(<\|end\|>)\s*|(<\|raw\|>)\s*"# - let captureRegex = try! NSRegularExpression(pattern: addedTokensRegexp, options: []) - - #expect( - "eating <|raw|> meat <|end|> That's all".split(by: captureRegex) == ["eating ", "<|raw|>", "meat ", "<|end|>", "That's all"] - ) - - #expect( - "<|raw|>".split(by: captureRegex) == ["<|raw|>"] - ) - - #expect( - "This string doesn't have those separators".split(by: captureRegex) == ["This string doesn't have those separators"] - ) - - #expect( - "start <|end|>".split(by: captureRegex) == ["start ", "<|end|>"] - ) - - #expect( - "start <|end|> ".split(by: captureRegex) == ["start ", "<|end|>"] - ) - - #expect( - "start <|end|> ".split(by: captureRegex) == ["start ", "<|end|>"] - ) - - #expect( - "start <|end|> for real".split(by: captureRegex) == ["start ", "<|end|>", "for real"] - ) - - #expect( - "<|raw|><|end|>".split(by: captureRegex) == ["<|raw|>", "<|end|>"] - ) - } - - @Test("Split behavior merged with previous") - func splitBehaviorMergedWithPrevious() { - #expect( - "the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .mergedWithPrevious) == ["the-", "final-", "-", "countdown"] - ) - - #expect( - "the-final--countdown-".split(by: "-", options: .caseInsensitive, behavior: .mergedWithPrevious) == ["the-", "final-", "-", "countdown-"] - ) - - #expect( - "the-final--countdown--".split(by: "-", options: .caseInsensitive, behavior: .mergedWithPrevious) == ["the-", "final-", "-", "countdown-", "-"] - ) - - #expect( - "-the-final--countdown--".split(by: "-", options: .caseInsensitive, behavior: .mergedWithPrevious) == ["-", "the-", "final-", "-", "countdown-", "-"] - ) - - #expect( - "--the-final--countdown--".split(by: "-", options: .caseInsensitive, behavior: .mergedWithPrevious) == ["-", "-", "the-", "final-", "-", "countdown-", "-"] - ) - } - - @Test("Split behavior merged with next") - func splitBehaviorMergedWithNext() { - #expect( - "the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .mergedWithNext) == ["the", "-final", "-", "-countdown"] - ) - - #expect( - "-the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .mergedWithNext) == ["-the", "-final", "-", "-countdown"] - ) - - #expect( - "--the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .mergedWithNext) == ["-", "-the", "-final", "-", "-countdown"] - ) - - #expect( - "--the-final--countdown-".split(by: "-", options: .caseInsensitive, behavior: .mergedWithNext) == ["-", "-the", "-final", "-", "-countdown", "-"] - ) - } - - @Test("Split behavior isolated and removed") - func splitBehaviorOther() { - #expect( - "the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .isolated) == ["the", "-", "final", "-", "-", "countdown"] - ) - - #expect( - "the-final--countdown".split(by: "-", options: .caseInsensitive, behavior: .removed) == ["the", "final", "countdown"] - ) - } -}