
This is a fundamental limitation because `device.type` is a C-level property that cannot be patched from Python. Downstream projects that check `device.type == "cuda"` need to be patched to use `torchada.is_gpu_device()` or check for both types: `device.type in ("cuda", "musa")`.
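
A minimal sketch of such a patched check (the `is_on_gpu` helper name is illustrative, not part of torchada):

```python
import torch
import torchada  # noqa: F401  (apply patches before any torch.cuda use)

def is_on_gpu(t: torch.Tensor) -> bool:
    # Before: t.device.type == "cuda" silently returns False on MUSA.
    # Either check both device types or use the torchada helper:
    return t.device.type in ("cuda", "musa")
    # Equivalent: return torchada.is_gpu_device(t)
```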

## Real-World Integrations

torchada has been successfully integrated into several popular PyTorch-based projects. Below are examples demonstrating the typical integration patterns.

### Integrated Projects

| Project | Category | PR | Status |
|---------|----------|--------|--------|
| [ComfyUI](https://github.com/comfyanonymous/ComfyUI) | Image/Video Generation | [#11618](https://github.com/comfyanonymous/ComfyUI/pull/11618) | Open |
| [LightLLM](https://github.com/ModelTC/LightLLM) | LLM Inference | [#1162](https://github.com/ModelTC/LightLLM/pull/1162) | Open |
| [Xinference](https://github.com/xorbitsai/inference) | Model Serving | [#4425](https://github.com/xorbitsai/inference/pull/4425) | ✅ Merged |
| [LightX2V](https://github.com/ModelTC/LightX2V) | Image/Video Generation | [#678](https://github.com/ModelTC/LightX2V/pull/678) | ✅ Merged |

### Integration Patterns

#### Pattern 1: Early Import with Platform Detection

The most common pattern is to import `torchada` early in the application lifecycle:

```python
# your_app/device_utils.py: platform detection function
def is_musa():
    import torch
    return hasattr(torch.version, "musa") and torch.version.musa is not None

# In __init__.py or main entry point
from your_app.device_utils import is_musa

if is_musa():
    import torchada  # noqa: F401
```

This pattern is used by **LightLLM** and **LightX2V**.
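
The point of the early import is ordering: on a MUSA machine it must run before any `torch.cuda` call so the redirection is already in place. A hypothetical smoke test of that guarantee, assuming torchada's redirection described above:

```python
import torchada  # noqa: F401  (must come before any torch.cuda use)
import torch

# With torchada's patches applied on a MUSA machine, code written
# against the torch.cuda API runs without modification.
print(torch.cuda.is_available())
print(torch.cuda.device_count())
```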

#### Pattern 2: Add to Dependencies

Add `torchada` to your project's dependencies:

```toml
# pyproject.toml
dependencies = [
    "torchada>=0.1.11",
]

# Or in requirements.txt:
torchada>=0.1.11
```

#### Pattern 3: Device Availability Check

Create a device availability function that checks for MUSA:

```python
import torch

def is_musa_available() -> bool:
    try:
        import torch_musa  # noqa: F401
        import torchada  # noqa: F401
        return torch.musa.is_available()
    except ImportError:
        return False

def get_available_device():
    if torch.cuda.is_available():
        return "cuda"
    elif is_musa_available():
        return "musa"
    return "cpu"
```

This pattern is used by **Xinference**.
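
Call sites can then stay backend-agnostic. A minimal usage sketch building on `get_available_device()` above (the model and batch are placeholders):

```python
import torch

device = get_available_device()             # "cuda", "musa", or "cpu"
model = torch.nn.Linear(16, 16).to(device)
batch = torch.randn(4, 16, device=device)
out = model(batch)                          # runs on whichever backend was detected
```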

#### Pattern 4: Platform-Specific Feature Flags

Enable features based on platform capabilities:

```python
import torch
import torchada  # noqa: F401

musa_available = hasattr(torch, "musa") and torch.musa.is_available()

def is_musa():
    return musa_available

# Enable NVIDIA-like optimizations on MUSA
# (is_nvidia() and get_total_memory() are the application's own helpers)
if is_nvidia() or is_musa():
    ENABLE_PYTORCH_ATTENTION = True
    NUM_STREAMS = 2  # Async weight offloading
    MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.9
```

This pattern is used by **ComfyUI**.
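
Here `is_nvidia()` is the application's own helper, not shown in the snippet. One possible implementation, sketched as an assumption:

```python
import torch

def is_nvidia() -> bool:
    # torch.version.cuda is a version string on CUDA builds and None
    # elsewhere (MUSA builds expose torch.version.musa instead), so this
    # stays False on MUSA even though torch.cuda calls are redirected.
    return torch.version.cuda is not None and torch.cuda.is_available()
```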

#### Pattern 5: Platform Device Classes

For projects with a device abstraction layer:

```python
from your_platform.base.nvidia import CudaDevice
from your_platform.registry import PLATFORM_DEVICE_REGISTER

@PLATFORM_DEVICE_REGISTER("musa")
class MusaDevice(CudaDevice):
    name = "cuda"  # Use CUDA APIs (redirected by torchada)

    @staticmethod
    def is_available() -> bool:
        try:
            import torch
            import torchada  # noqa: F401
            return hasattr(torch, "musa") and torch.musa.is_available()
        except ImportError:
            return False
```

This pattern is used by **LightX2V**.
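
For projects that do not already have such a layer, here is a minimal sketch of what the registry behind `PLATFORM_DEVICE_REGISTER` might look like (illustrative, not LightX2V's actual code):

```python
# Hypothetical decorator-based device registry.
_DEVICE_CLASSES: dict = {}

def PLATFORM_DEVICE_REGISTER(key: str):
    def decorator(cls):
        _DEVICE_CLASSES[key] = cls
        return cls
    return decorator

def pick_device_class():
    # Return the first registered platform whose backend is available.
    for cls in _DEVICE_CLASSES.values():
        if cls.is_available():
            return cls
    raise RuntimeError("No available platform device")
```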

### Common Integration Steps

1. **Add dependency**: Add `torchada>=0.1.11` to your project dependencies

2. **Import early**: Import `torchada` before using any `torch.cuda` APIs
   ```python
   import torchada  # Apply patches
   import torch
   ```

3. **Add platform detection**: Create `is_musa()` function for platform-specific code
   ```python
   def is_musa():
       return hasattr(torch.version, "musa") and torch.version.musa is not None
   ```

4. **Update feature flags**: Include MUSA in capability checks
   ```python
   if is_nvidia() or is_musa():
       ...  # Enable GPU-specific features
   ```

5. **Handle device type checks**: Use `torchada.is_gpu_device()` or check both types
   ```python
   # Instead of: device.type == "cuda"
   # Use:        device.type in ("cuda", "musa")
   # Or:         torchada.is_gpu_device(device)
   ```

## Architecture

torchada uses a decorator-based patch registration system: