
Commit 2dafb02

withDefaultDevice() (#15)
Expose SHLLM.withDefaultDevice(_:_:) to allow clients to run inference on the CPU or GPU.
1 parent 0811eb0 commit 2dafb02
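
In client code this looks roughly like the following sketch (runInference() is a hypothetical stand-in for the caller's own work):

import SHLLM

// Run the enclosed work with the CPU as MLX's default device;
// the previous default applies again once the closure returns.
let answer = try await SHLLM.withDefaultDevice(.cpu) {
    try await runInference()
}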

2 files changed (+93, −0)

Sources/SHLLM/SHLLM.swift

Lines changed: 26 additions & 0 deletions
@@ -40,6 +40,30 @@ public enum SHLLM {
         return true
     }
 
+    public static func withDefaultDevice<R>(
+        _ device: MLX.DeviceType,
+        _ body: () throws -> R
+    ) rethrows -> R {
+        switch device {
+        case .cpu:
+            try MLX.Device.withDefaultDevice(.cpu, body)
+        case .gpu:
+            try MLX.Device.withDefaultDevice(.gpu, body)
+        }
+    }
+
+    public static func withDefaultDevice<R>(
+        _ device: MLX.DeviceType,
+        _ body: () async throws -> R
+    ) async rethrows -> R {
+        switch device {
+        case .cpu:
+            try await MLX.Device.withDefaultDevice(.cpu, body)
+        case .gpu:
+            try await MLX.Device.withDefaultDevice(.gpu, body)
+        }
+    }
+
     static var assertSupportedDevice: Void {
         get throws {
             guard isSupportedDevice else {
@@ -60,6 +84,8 @@ public enum SHLLM {
 
 extension Chat.Message: @retroactive @unchecked Sendable {}
 
+@_exported import enum MLX.DeviceType
+
 @_exported import protocol MLXLMCommon.LanguageModel
 
 @_exported import class MLXLLM.Gemma2Model
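
Both overloads simply map the re-exported MLX.DeviceType onto a concrete MLX.Device before delegating to MLX. Because the synchronous overload is rethrows, no try is needed when the closure cannot throw. A minimal sketch of the synchronous form (the array values are illustrative only):

import MLX
import SHLLM

// Evaluate a small MLX computation with the CPU as the default device.
let sum = SHLLM.withDefaultDevice(.cpu) {
    MLXArray([1, 2, 3]) + MLXArray([4, 5, 6])
}

The async overload exists so callers can await work, such as streaming tokens, inside the device scope, which is what the tests below do.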

Tests/SHLLMTests/SHLLMTests.swift

Lines changed: 67 additions & 0 deletions
@@ -27,3 +27,70 @@ func recommendedMaxWorkingSetSize() async throws {
     let recommended = SHLLM.recommendedMaxWorkingSetSize
     #expect(recommended > 0)
 }
+
+// NOTE: Running inference on the CPU takes way too long.
+@Test(.enabled(if: false))
+func onCPU() async throws {
+    guard SHLLM.isSupportedDevice else {
+        Swift.print("⚠️ Metal GPU not available")
+        return
+    }
+
+    let input: UserInput = .init(messages: [
+        ["role": "system", "content": "You are a helpful assistant."],
+        ["role": "user", "content": "What is the meaning of life?"],
+    ])
+
+    try await SHLLM.withDefaultDevice(.cpu) {
+        guard let llm = try loadModel(
+            directory: LLM.gemma3_1B,
+            input: input,
+            customConfiguration: { config in
+                var config = config
+                config.extraEOSTokens = ["<end_of_turn>"]
+                return config
+            }
+        ) as LLM<Gemma3TextModel>? else { return }
+
+        var response = ""
+        for try await token in llm.text {
+            response += token
+        }
+
+        Swift.print(response)
+        #expect(!response.isEmpty)
+    }
+}
+
+@Test()
+func onGPU() async throws {
+    guard SHLLM.isSupportedDevice else {
+        Swift.print("⚠️ Metal GPU not available")
+        return
+    }
+
+    let input: UserInput = .init(messages: [
+        ["role": "system", "content": "You are a helpful assistant."],
+        ["role": "user", "content": "What is the meaning of life?"],
+    ])
+
+    try await SHLLM.withDefaultDevice(.gpu) {
+        guard let llm = try loadModel(
+            directory: LLM.gemma3_1B,
+            input: input,
+            customConfiguration: { config in
+                var config = config
+                config.extraEOSTokens = ["<end_of_turn>"]
+                return config
+            }
+        ) as LLM<Gemma3TextModel>? else { return }
+
+        var response = ""
+        for try await token in llm.text {
+            response += token
+        }
+
+        Swift.print(response)
+        #expect(!response.isEmpty)
+    }
+}
