feat: add full MPS (Apple Silicon) support#288
Open
gruckion wants to merge 2 commits intonari-labs:mainfrom
Open
feat: add full MPS (Apple Silicon) support#288gruckion wants to merge 2 commits intonari-labs:mainfrom
gruckion wants to merge 2 commits intonari-labs:mainfrom
Conversation
This commit enables true MPS GPU acceleration for macOS users: - Guard cudagraph_mark_step_begin() with CUDA device check (dia/model.py) - Skip torch.compile on non-CUDA devices with warning (dia/model.py) - Add MPS to CLI device auto-detection (cli.py) - Fix simple-mac.py to use float32 for MPS stability (example/simple-mac.py) - Guard Triton config with platform check (example/benchmark.py) - Add MPS seed management to set_seed() functions (cli.py, app.py) - Add ARM64 Linux Dockerfile for Graviton/ARM servers (docker/Dockerfile.arm) Tested on Apple Silicon with ~0.16x realtime performance. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR enables true MPS GPU acceleration for macOS Apple Silicon users, providing a 44x speedup over CPU fallback.
Performance Benchmarks
Background
PR #167 added
example/simple-mac.pyas a Mac support solution, but it was actually a CPU fallback workaround that avoided MPS issues rather than fixing them. This PR addresses the root causes to enable true MPS GPU acceleration.Changes
1. Guard CUDAGraph Call (Critical)
File:
dia/model.py:701-702Why:
cudagraph_mark_step_begin()is CUDA-specific and has no MPS equivalent. Called ~860 times per generation in the autoregressive loop.2. Add Device Check to torch.compile (Critical)
File:
dia/model.py:657-667Why:
mode="max-autotune"requires Triton which is CUDA/ROCm only. MPS torch.compile is still experimental (PyTorch #150121).3. Fix CLI Device Auto-Detection
File:
cli.py:12-18, 82-83Why: Mac users running
python cli.py "text" --output out.wavnow automatically get MPS acceleration instead of slow CPU fallback.4. Fix simple-mac.py dtype
File:
example/simple-mac.py:4-6Why: Matches
app.pyrecommendation. MPS float16 has documented precision bugs and provides minimal speedup on Apple Silicon (no dedicated Tensor Cores).5. Guard Triton Config in Benchmark
File:
example/benchmark.py:9-13Why: Triton is only available on Linux/Windows per
pyproject.toml. These settings fail on macOS.6. Add MPS Seed Management
Files:
cli.py:30-32,app.py:66-68Why: Explicit MPS RNG seeding for reproducibility (though MPS reproducibility has known limitations).
7. Add ARM64 Linux Dockerfile
File:
docker/Dockerfile.arm(new)Adds support for ARM64 Linux servers (AWS Graviton, Ampere Altra, etc.) with CPU-only inference.
Note: Docker cannot access MPS on macOS (runs in Linux VM). macOS users must install natively for GPU acceleration.
Test Plan
python example/simple-mac.pyruns without errors on MPSpython cli.py "[S1] Hello world." --output test.wavauto-detects MPSReferences
🤖 Generated with Claude Code