Skip to content

Commit

Permalink
Add ability to prevent config.json being written to `~/Documents/hugg…
Browse files Browse the repository at this point in the history
…ingface/...` (#262)

* `fetchAvailableModels` and `fetchAvailableModels` take `downloadBase: URL?` param

Which controls where HubAPI leaves the config.json that it downloads.

* `recommendedRemoteModels` also gets passed the downloadBase (and repo param)
  • Loading branch information
iandundas authored Nov 19, 2024
1 parent e8eebbe commit 0af7146
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions Sources/WhisperKit/Core/WhisperKit.swift
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,14 @@ open class WhisperKit {
return modelSupport(for: deviceName)
}

public static func recommendedRemoteModels(from repo: String = "argmaxinc/whisperkit-coreml") async -> ModelSupport {
public static func recommendedRemoteModels(from repo: String = "argmaxinc/whisperkit-coreml", downloadBase: URL? = nil) async -> ModelSupport {
let deviceName = Self.deviceName()
let config = await Self.fetchModelSupportConfig(from: repo)
let config = await Self.fetchModelSupportConfig(from: repo, downloadBase: downloadBase)
return modelSupport(for: deviceName, from: config)
}

public static func fetchModelSupportConfig(from repo: String = "argmaxinc/whisperkit-coreml") async -> ModelSupportConfig {
let hubApi = HubApi()
public static func fetchModelSupportConfig(from repo: String = "argmaxinc/whisperkit-coreml", downloadBase: URL? = nil) async -> ModelSupportConfig {
let hubApi = HubApi(downloadBase: downloadBase)
var modelSupportConfig = Constants.fallbackModelSupportConfig

do {
Expand All @@ -176,8 +176,8 @@ open class WhisperKit {
return modelSupportConfig
}

public static func fetchAvailableModels(from repo: String = "argmaxinc/whisperkit-coreml", matching: [String] = ["*"]) async throws -> [String] {
let modelSupportConfig = await fetchModelSupportConfig(from: repo)
public static func fetchAvailableModels(from repo: String = "argmaxinc/whisperkit-coreml", matching: [String] = ["*"], downloadBase: URL? = nil) async throws -> [String] {
let modelSupportConfig = await fetchModelSupportConfig(from: repo, downloadBase: downloadBase)
let supportedModels = modelSupportConfig.modelSupport().supported
var filteredSupportSet: Set<String> = []
for glob in matching {
Expand Down Expand Up @@ -290,10 +290,10 @@ open class WhisperKit {
self.modelFolder = URL(fileURLWithPath: folder)
} else if download {
// Determine the model variant to use
let modelSupport = await WhisperKit.recommendedRemoteModels()
let repo = modelRepo ?? "argmaxinc/whisperkit-coreml"
let modelSupport = await WhisperKit.recommendedRemoteModels(from: repo, downloadBase: downloadBase)
let modelVariant = model ?? modelSupport.default

let repo = modelRepo ?? "argmaxinc/whisperkit-coreml"
do {
self.modelFolder = try await Self.download(
variant: modelVariant,
Expand Down

0 comments on commit 0af7146

Please sign in to comment.