diff --git a/README.md b/README.md
index dc7413a..c35c7fe 100644
--- a/README.md
+++ b/README.md
@@ -283,6 +283,152 @@ if torchada.is_gpu_device(tensor):
 This is a fundamental limitation because `device.type` is a C-level property that cannot be patched from Python. Downstream projects that check `device.type == "cuda"` need to be patched to use `torchada.is_gpu_device()` or check for both types: `device.type in ("cuda", "musa")`.
 
+## Real-World Integrations
+
+torchada has been integrated into several popular PyTorch-based projects. The examples below demonstrate the typical integration patterns.
+
+### Integrated Projects
+
+| Project | Category | PR | Status |
+|---------|----------|----|--------|
+| [ComfyUI](https://github.com/comfyanonymous/ComfyUI) | Image/Video Generation | [#11618](https://github.com/comfyanonymous/ComfyUI/pull/11618) | Open |
+| [LightLLM](https://github.com/ModelTC/LightLLM) | LLM Inference | [#1162](https://github.com/ModelTC/LightLLM/pull/1162) | Open |
+| [Xinference](https://github.com/xorbitsai/inference) | Model Serving | [#4425](https://github.com/xorbitsai/inference/pull/4425) | ✅ Merged |
+| [LightX2V](https://github.com/ModelTC/LightX2V) | Image/Video Generation | [#678](https://github.com/ModelTC/LightX2V/pull/678) | ✅ Merged |
+
+### Integration Patterns
+
+#### Pattern 1: Early Import with Platform Detection
+
+The most common pattern is to import `torchada` early in the application lifecycle:
+
+```python
+# your_app/device_utils.py -- platform detection helper
+import torch
+
+def is_musa():
+    return hasattr(torch.version, "musa") and torch.version.musa is not None
+
+# __init__.py or main entry point
+from your_app.device_utils import is_musa
+
+if is_musa():
+    import torchada  # noqa: F401
+```
+
+This pattern is used by **LightLLM** and **LightX2V**.
+
+#### Pattern 2: Add to Dependencies
+
+Add `torchada` to your project's dependencies:
+
+```toml
+# pyproject.toml
+[project]
+dependencies = [
+    "torchada>=0.1.11",
+]
+
+# Or in requirements.txt:
+# torchada>=0.1.11
+```
+
+#### Pattern 3: Device Availability Check
+
+Create a device availability function that checks for MUSA:
+
+```python
+import torch
+
+def is_musa_available() -> bool:
+    try:
+        import torch_musa  # noqa: F401
+        import torchada  # noqa: F401
+        return torch.musa.is_available()
+    except ImportError:
+        return False
+
+def get_available_device():
+    if torch.cuda.is_available():
+        return "cuda"
+    elif is_musa_available():
+        return "musa"
+    return "cpu"
+```
+
+This pattern is used by **Xinference**.
+
+#### Pattern 4: Platform-Specific Feature Flags
+
+Enable features based on platform capabilities:
+
+```python
+import torch
+import torchada  # noqa: F401
+
+musa_available = hasattr(torch, "musa") and torch.musa.is_available()
+
+def is_musa():
+    return musa_available
+
+# Enable NVIDIA-like optimizations on MUSA.
+# is_nvidia() and get_total_memory() are the application's own helpers.
+if is_nvidia() or is_musa():
+    ENABLE_PYTORCH_ATTENTION = True
+    NUM_STREAMS = 2  # Async weight offloading
+    MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.9
+```
+
+This pattern is used by **ComfyUI**.
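+
+Whichever pattern a project adopts, a quick smoke test can confirm that torchada is actually redirecting `torch.cuda` calls. The sketch below is illustrative rather than taken from any of the projects listed above, and it assumes a machine with a working `torch_musa` installation:
+
+```python
+import torchada  # must be imported before any torch.cuda usage so the patches apply
+import torch
+
+def smoke_test() -> None:
+    # With torchada loaded, torch.cuda APIs are redirected to MUSA on
+    # Moore Threads hardware and behave as usual on NVIDIA hardware.
+    if torch.cuda.is_available():
+        x = torch.ones(4, device="cuda")
+        # device.type is a C-level property that torchada cannot patch, so
+        # expect "musa:0" here on MUSA hardware and "cuda:0" on NVIDIA.
+        print(x.device)
+        print(torch.cuda.device_count())
+
+if __name__ == "__main__":
+    smoke_test()
+```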
+
+#### Pattern 5: Platform Device Classes
+
+For projects with a device abstraction layer:
+
+```python
+from your_platform.base.nvidia import CudaDevice
+from your_platform.registry import PLATFORM_DEVICE_REGISTER
+
+@PLATFORM_DEVICE_REGISTER("musa")
+class MusaDevice(CudaDevice):
+    name = "cuda"  # Use CUDA APIs (redirected by torchada)
+
+    @staticmethod
+    def is_available() -> bool:
+        try:
+            import torch
+            import torchada  # noqa: F401
+            return hasattr(torch, "musa") and torch.musa.is_available()
+        except ImportError:
+            return False
+```
+
+This pattern is used by **LightX2V**.
+
+### Common Integration Steps
+
+1. **Add dependency**: Add `torchada>=0.1.11` to your project's dependencies.
+
+2. **Import early**: Import `torchada` before using any `torch.cuda` APIs:
+   ```python
+   import torchada  # Apply patches
+   import torch
+   ```
+
+3. **Add platform detection**: Create an `is_musa()` function for platform-specific code:
+   ```python
+   def is_musa():
+       return hasattr(torch.version, "musa") and torch.version.musa is not None
+   ```
+
+4. **Update feature flags**: Include MUSA in capability checks:
+   ```python
+   if is_nvidia() or is_musa():
+       ENABLE_PYTORCH_ATTENTION = True  # Enable GPU-specific features
+   ```
+
+5. **Handle device type checks**: Use `torchada.is_gpu_device()` or check both device types:
+   ```python
+   # Instead of: device.type == "cuda"
+   # Use:        device.type in ("cuda", "musa")
+   # Or:         torchada.is_gpu_device(device)
+   ```
+
 ## Architecture
 
 torchada uses a decorator-based patch registration system: