diff --git a/examples/distributed_inference/tensor_parallel_initialize_dist.py b/examples/distributed_inference/tensor_parallel_initialize_dist.py index 98d3ca18e9..0ec8e6cd19 100644 --- a/examples/distributed_inference/tensor_parallel_initialize_dist.py +++ b/examples/distributed_inference/tensor_parallel_initialize_dist.py @@ -3,7 +3,7 @@ Tensor Parallel Initialize Distributed Environment ================================================== -This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference. +This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference. These utilities are useful for tensor parallel distributed inference examples using torch.distributed. """ import logging @@ -14,32 +14,68 @@ import tensorrt as trt import torch import torch.distributed as dist -from torch.distributed._tensor.device_mesh import init_device_mesh +from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh -def find_repo_root(max_depth=10): - dir_path = os.path.dirname(os.path.realpath(__file__)) - for i in range(max_depth): - files = os.listdir(dir_path) - if "MODULE.bazel" in files: - return dir_path - else: - dir_path = os.path.dirname(dir_path) +def initialize_logger( + rank, logger_file_name, file_level=logging.DEBUG, console_level=logging.INFO +): + """Initialize rank-specific Torch-TensorRT logger with configurable handler levels. 
- raise RuntimeError("Could not find repo root") + Logger level is set to DEBUG (pass-through), handlers control filtering for files and stream buffers + Args: + rank: Process rank for multi-GPU + logger_file_name: Base name for log file (will add _rank.log) + file_level: What goes to file - default DEBUG (everything) + console_level: What prints to console - default INFO (clean output) + """ + logger = logging.getLogger("torch_tensorrt") + logger.setLevel(logging.DEBUG) + logger.handlers.clear() -def initialize_logger(rank, logger_file_name): - logger = logging.getLogger() - logger.setLevel(logging.INFO) + # File handler fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w") - fh.setLevel(logging.INFO) + fh.setLevel(file_level) + fh.setFormatter( + logging.Formatter( + f"[Rank {rank}] %(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + ) logger.addHandler(fh) + + # console handler + ch = logging.StreamHandler() + ch.setLevel( + console_level + ) # Console handler controls what's printed in console output + ch.setFormatter(logging.Formatter(f"[Rank {rank}] %(levelname)s: %(message)s")) + logger.addHandler(ch) + + # safeguard, though not strictly required + logger.propagate = False return logger # This is required for env initialization since we use mpirun -def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500): +def initialize_distributed_env( + logger_file_name, + rank=0, + world_size=1, + port=29500, + file_level="debug", + console_level="info", +): + """Initialize distributed environment with handler-based logging. 
+ + Args: + logger_file_name: Base name for log files + rank: Initial rank (overridden by OMPI env vars) + world_size: Initial world size (overridden by OMPI env vars) + port: Master port for distributed communication + file_level: File handler level - "debug", "info", "warning" (default: "debug") + console_level: Console handler level - "debug", "info", "warning" (default: "info") + """ local_rank = int( os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count()) ) @@ -50,9 +86,6 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950 os.environ["WORLD_SIZE"] = str(world_size) os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = str(port) - os.environ["TRTLLM_PLUGINS_PATH"] = ( - find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so" - ) # Necessary to assign a device to each rank. torch.cuda.set_device(local_rank) @@ -66,12 +99,39 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950 device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,)) rank = device_mesh.get_rank() assert rank == local_rank - logger = initialize_logger(rank, logger_file_name) + # Convert string handler levels to logging constants + level_map = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + } + file_level_int = level_map.get(file_level.lower(), logging.DEBUG) + console_level_int = level_map.get(console_level.lower(), logging.INFO) + + # Initialize logger with handler-specific levels + # Logger itself is always DEBUG - handlers do the filtering + logger = initialize_logger( + rank, + logger_file_name, + file_level=file_level_int, + console_level=console_level_int, + ) device_id = ( rank % torch.cuda.device_count() ) # Ensure each rank gets a unique device torch.cuda.set_device(device_id) + # Set C++ TensorRT runtime log level based on most verbose handler + # Use the most verbose level to ensure all important logs are 
captured + cpp_level = min(file_level_int, console_level_int) + try: + import torch_tensorrt.logging as torchtrt_logging + + torchtrt_logging.set_level(cpp_level) + except Exception as e: + logger.warning(f"Could not set C++ TensorRT log level: {e}") + return device_mesh, world_size, rank, logger @@ -79,3 +139,28 @@ def cleanup_distributed_env(): """Clean up distributed process group to prevent resource leaks.""" if dist.is_initialized(): dist.destroy_process_group() + + +def check_tensor_parallel_device_number(world_size: int) -> None: + if world_size % 2 != 0: + raise ValueError( + f"TP examples require even number of GPUs, but got {world_size} gpus" + ) + + +def get_tensor_parallel_device_mesh( + rank: int = 0, world_size: int = 1 +) -> tuple[DeviceMesh, int, int]: + local_rank = int( + os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count()) + ) + world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size)) + device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,)) + rank = device_mesh.get_rank() + assert rank == local_rank + device_id = ( + rank % torch.cuda.device_count() + ) # Ensure each rank gets a unique device + torch.cuda.set_device(device_id) + + return device_mesh, world_size, rank diff --git a/examples/distributed_inference/tensor_parallel_rotary_embedding.py b/examples/distributed_inference/tensor_parallel_rotary_embedding.py index da3f3fd8fd..bd1216bc75 100644 --- a/examples/distributed_inference/tensor_parallel_rotary_embedding.py +++ b/examples/distributed_inference/tensor_parallel_rotary_embedding.py @@ -9,20 +9,19 @@ """ -import logging -import os import time import torch -import torch_tensorrt from rotary_embedding import RotaryAttention, parallel_rotary_block from tensor_parallel_initialize_dist import ( cleanup_distributed_env, initialize_distributed_env, ) +# Initialize distributed environment and logger BEFORE importing torch_tensorrt +# This ensures logging is configured before any 
import-time log messages device_mesh, _world_size, _rank, logger = initialize_distributed_env( - "./tensor_parallel_rotary_embedding" + "tensor_parallel_rotary_embedding" ) diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py index c5688c6e5b..f2dc6861cb 100755 --- a/examples/distributed_inference/tensor_parallel_simple_example.py +++ b/examples/distributed_inference/tensor_parallel_simple_example.py @@ -1,7 +1,7 @@ """ .. _tensor_parallel_simple_example: -Torch Parallel Distributed example for simple model +Tensor Parallel Distributed Inference with Torch-TensorRT ========================================= Below example shows how to use Torch-TensorRT backend for distributed inference with tensor parallelism. @@ -25,11 +25,18 @@ import torch import torch.distributed as dist import torch.nn as nn -import torch_tensorrt from tensor_parallel_initialize_dist import ( cleanup_distributed_env, + get_tensor_parallel_device_mesh, initialize_distributed_env, ) + +# Initialize distributed environment and logger BEFORE importing torch_tensorrt +# This ensures logging is configured before any import-time log messages +device_mesh, _world_size, _rank, logger = initialize_distributed_env( + "tensor_parallel_simple_example" +) + from torch.distributed._tensor import Shard from torch.distributed.tensor.parallel import ( ColwiseParallel, @@ -37,10 +44,6 @@ parallelize_module, ) -device_mesh, _world_size, _rank, logger = initialize_distributed_env( - "./tensor_parallel_simple_example" -) - """ This example takes some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py """ diff --git a/py/torch_tensorrt/_features.py b/py/torch_tensorrt/_features.py index 3c4df0fbde..586153cf15 100644 --- a/py/torch_tensorrt/_features.py +++ b/py/torch_tensorrt/_features.py @@ -76,10 +76,29 @@ def _enabled_features_str() -> str: enabled = lambda x: 
"ENABLED" if x else "DISABLED" - out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n - Refit: {enabled(_REFIT_AVAIL)}\n - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)} \n - TensorRT-RTX: {enabled(_TENSORRT_RTX)}\n" # type: ignore[no-untyped-call] + out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n - Refit: {enabled(_REFIT_AVAIL)}\n - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)} \n - TensorRT-RTX: {enabled(_TENSORRT_RTX)}\n - TensorRT-LLM for NCCL: {enabled(_TRTLLM_AVAIL)}\n" # type: ignore[no-untyped-call] return out_str +# Inline helper functions for checking feature availability +def has_torch_tensorrt_runtime() -> bool: + """Check if Torch-TensorRT C++ runtime is available. + + Returns: + bool: True if libtorchtrt_runtime.so or libtorchtrt.so is available + """ + return bool(ENABLED_FEATURES.torch_tensorrt_runtime) + + +def has_torchscript_frontend() -> bool: + """Check if TorchScript frontend is available. + + Returns: + bool: True if libtorchtrt.so is available + """ + return bool(ENABLED_FEATURES.torchscript_frontend) + + def needs_tensorrt_rtx(f: Callable[..., Any]) -> Callable[..., Any]: def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: if ENABLED_FEATURES.tensorrt_rtx: @@ -165,6 +184,19 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: def needs_trtllm_for_nccl(f: Callable[..., Any]) -> Callable[..., Any]: + """ + Runtime check decorator for TensorRT-LLM NCCL plugin availability. + + WARNING: This decorator CANNOT prevent registration of converters at import time. 
+ When used with @dynamo_tensorrt_converter, the converter is always registered + regardless of decorator order, because registration happens at import time before + the wrapper is called. + + This decorator is kept for potential non-registration use cases where + runtime checks are appropriate. + @apbose: to discuss if this is required + """ + def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: if ENABLED_FEATURES.trtllm_for_nccl: return f(*args, **kwargs) @@ -172,7 +204,7 @@ def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: raise NotImplementedError( - "Refit feature is currently not available in Python 3.13 or higher" + "TensorRT-LLM plugin for NCCL is not available" ) return not_implemented(*args, **kwargs) diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 59783de665..299fa59947 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -7,6 +7,7 @@ import torch import torch._dynamo as td +import torch_tensorrt.logging as torchtrt_logging from torch._dynamo.backends.common import aot_autograd from torch._dynamo.utils import detect_fake_mode from torch._functorch.aot_autograd import aot_export_joint_simple @@ -23,7 +24,6 @@ from torch_tensorrt.dynamo.utils import ( parse_dynamo_kwargs, prepare_inputs, - set_log_level, ) logger = logging.getLogger(__name__) @@ -40,7 +40,7 @@ def torch_tensorrt_backend( and "debug" in kwargs["options"] and kwargs["options"]["debug"] ) or ("debug" in kwargs and kwargs["debug"]): - set_log_level(logger.parent, logging.DEBUG) + torchtrt_logging.set_level(logging.DEBUG) DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend diff --git a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py index db14e3528b..302a254f60 100644 --- 
a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py @@ -5,7 +5,7 @@ import tensorrt as trt from torch.fx.node import Argument, Target -from torch_tensorrt._features import needs_trtllm_for_nccl +from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext @@ -20,37 +20,53 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) -@needs_trtllm_for_nccl -@dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op) -def fused_nccl_gather( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[trt.ITensor, Sequence[trt.ITensor]]: - return impl.nccl_ops.nccl_gather( - ctx, - target, - SourceIR.ATEN, - name, - [args[0]], +# Conditionally register NCCL converters only if TensorRT-LLM plugin is available. +# We use an `if` statement instead of @needs_trtllm_for_nccl decorator because +# @dynamo_tensorrt_converter ALWAYS registers at import time regardless of decorator +# order. Conditional registration prevents registration when TRTLLM is unavailable, +# allowing fallback to PyTorch execution for NCCL ops. + +# Order 1: @needs_trtllm_for_nccl followed by registering the converter leads to plugin registry not finding nccl ops plugins since we register the bare converter, without the decorator +# Order 2: registering the converter first followed by @needs_trtllm_for_nccl leads to "NotImplementedError: TensorRT-LLM plugin for NCCL is not available :TensorRT-LLM plugin for NCCL is not available" and no fall back to pytorch +if ENABLED_FEATURES.trtllm_for_nccl: + _LOGGER.debug( + "TensorRT-LLM plugin for NCCL is available. Registering NCCL converters." 
) + @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op) + def fused_nccl_gather( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + return impl.nccl_ops.nccl_gather( + ctx, + target, + SourceIR.ATEN, + name, + [args[0]], + ) + + @dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op) + def fused_nccl_reduce_scatter( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + return impl.nccl_ops.nccl_reduce_scatter( + ctx, + target, + SourceIR.ATEN, + name, + [args[0]], + ) -@needs_trtllm_for_nccl -@dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op) -def fused_nccl_reduce_scatter( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[trt.ITensor, Sequence[trt.ITensor]]: - return impl.nccl_ops.nccl_reduce_scatter( - ctx, - target, - SourceIR.ATEN, - name, - [args[0]], +else: + _LOGGER.info( + "TensorRT-LLM plugin for NCCL is not available. " + "NCCL operations will fall back to PyTorch execution." ) diff --git a/py/torch_tensorrt/logging.py b/py/torch_tensorrt/logging.py index 0cba3bd510..80eafbd909 100644 --- a/py/torch_tensorrt/logging.py +++ b/py/torch_tensorrt/logging.py @@ -3,7 +3,10 @@ import tensorrt as trt import torch -from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt._features import ( + has_torch_tensorrt_runtime, + has_torchscript_frontend, +) logging.captureWarnings(True) _LOGGER = logging.getLogger("torch_tensorrt [TensorRT Conversion Context]") @@ -31,6 +34,61 @@ def log(self, severity: trt.ILogger.Severity, msg: str) -> None: TRT_LOGGER = _TRTLogger() +def set_level(level: int, logger: Any = None) -> None: + """Set log level for both Python and C++ torch_tensorrt loggers. 
+ + Permanently sets the log level until changed again or process exits. + Automatically handles runtime availability checks. + + This sets the log level for: + - Specified Python logger (or root torch_tensorrt logger if None) + - TorchScript frontend C++ logger (if available) + - Dynamo runtime C++ logger (if available) + + Args: + level: Python logging level (logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR, logging.CRITICAL) + logger: Optional logger to set level for. If None, sets the root torch_tensorrt logger. + + Example: + + # Set debug logging for entire session + torch_tensorrt.logging.set_level(logging.DEBUG) + + # Or set for a specific logger + my_logger = logging.getLogger("torch_tensorrt.dynamo") + torch_tensorrt.logging.set_level(logging.DEBUG, logger=my_logger) + """ + # Set the specified logger or default to root torch_tensorrt logger + if logger is None: + logging.getLogger("torch_tensorrt").setLevel(level) + _LOGGER.setLevel(level) + else: + logger.setLevel(level) + + # runtime set log level + if has_torch_tensorrt_runtime(): + if level == logging.CRITICAL: + torch.ops.tensorrt.set_logging_level( + int(trt.ILogger.Severity.INTERNAL_ERROR) + ) + elif level == logging.ERROR: + torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.ERROR)) + elif level == logging.WARNING: + torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.WARNING)) + elif level == logging.INFO: + torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.INFO)) + elif level == logging.DEBUG: + torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.VERBOSE)) + elif level == logging.NOTSET: + # Graph level (most verbose) + torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.VERBOSE) + 1) + else: + raise ValueError( + f"Invalid log level: {level}. 
Must be one of: " + f"logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR, logging.CRITICAL, logging.NOTSET" + ) + + class internal_errors: """Context-manager to limit displayed log messages to just internal errors @@ -46,13 +104,13 @@ def __enter__(self) -> None: self.external_lvl = _LOGGER.getEffectiveLevel() _LOGGER.setLevel(logging.CRITICAL) - if ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging self.ts_level = ts_logging.get_reportable_log_level() ts_logging.set_reportable_log_level(ts_logging.Level.InternalError) - elif ENABLED_FEATURES.torch_tensorrt_runtime: + elif has_torch_tensorrt_runtime(): self.rt_level = torch.ops.tensorrt.get_logging_level() torch.ops.tensorrt.set_logging_level( int(trt.ILogger.Severity.INTERNAL_ERROR) @@ -61,12 +119,12 @@ def __enter__(self) -> None: def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: _LOGGER.setLevel(self.external_lvl) - if ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging ts_logging.set_reportable_log_level(self.ts_level) - elif ENABLED_FEATURES.torch_tensorrt_runtime: + elif has_torch_tensorrt_runtime(): torch.ops.tensorrt.set_logging_level(self.rt_level) @@ -85,25 +143,25 @@ def __enter__(self) -> None: self.external_lvl = _LOGGER.getEffectiveLevel() _LOGGER.setLevel(logging.ERROR) - if ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging self.ts_level = ts_logging.get_reportable_log_level() ts_logging.set_reportable_log_level(ts_logging.Level.Error) - elif ENABLED_FEATURES.torch_tensorrt_runtime: + elif has_torch_tensorrt_runtime(): self.rt_level = torch.ops.tensorrt.get_logging_level() torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.ERROR)) def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: _LOGGER.setLevel(self.external_lvl) - if 
ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging ts_logging.set_reportable_log_level(self.ts_level) - elif ENABLED_FEATURES.torch_tensorrt_runtime: + elif has_torch_tensorrt_runtime(): torch.ops.tensorrt.set_logging_level(self.rt_level) @@ -122,25 +180,25 @@ def __enter__(self) -> None: self.external_lvl = _LOGGER.getEffectiveLevel() _LOGGER.setLevel(logging.WARNING) - if ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging self.ts_level = ts_logging.get_reportable_log_level() ts_logging.set_reportable_log_level(ts_logging.Level.Warning) - elif ENABLED_FEATURES.torch_tensorrt_runtime: + elif has_torch_tensorrt_runtime(): self.rt_level = torch.ops.tensorrt.get_logging_level() torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.WARNING)) def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: _LOGGER.setLevel(self.external_lvl) - if ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging ts_logging.set_reportable_log_level(self.ts_level) - elif ENABLED_FEATURES.torch_tensorrt_runtime: + elif has_torch_tensorrt_runtime(): torch.ops.tensorrt.set_logging_level(self.rt_level) @@ -159,25 +217,25 @@ def __enter__(self) -> None: self.external_lvl = _LOGGER.getEffectiveLevel() _LOGGER.setLevel(logging.INFO) - if ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging self.ts_level = ts_logging.get_reportable_log_level() ts_logging.set_reportable_log_level(ts_logging.Level.Info) - elif ENABLED_FEATURES.torch_tensorrt_runtime: + elif has_torch_tensorrt_runtime(): self.rt_level = torch.ops.tensorrt.get_logging_level() torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.INFO)) def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: 
_LOGGER.setLevel(self.external_lvl) - if ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging ts_logging.set_reportable_log_level(self.ts_level) - elif ENABLED_FEATURES.torch_tensorrt_runtime: + elif has_torch_tensorrt_runtime(): torch.ops.tensorrt.set_logging_level(self.rt_level) @@ -196,25 +254,25 @@ def __enter__(self) -> None: self.external_lvl = _LOGGER.getEffectiveLevel() _LOGGER.setLevel(logging.DEBUG) - if ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging self.ts_level = ts_logging.get_reportable_log_level() ts_logging.set_reportable_log_level(ts_logging.Level.Debug) - elif ENABLED_FEATURES.torch_tensorrt_runtime: + elif has_torch_tensorrt_runtime(): self.rt_level = torch.ops.tensorrt.get_logging_level() torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.VERBOSE)) def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: _LOGGER.setLevel(self.external_lvl) - if ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging ts_logging.set_reportable_log_level(self.ts_level) - elif ENABLED_FEATURES.torch_tensorrt_runtime: + elif has_torch_tensorrt_runtime(): torch.ops.tensorrt.set_logging_level(self.rt_level) @@ -234,23 +292,23 @@ def __enter__(self) -> None: self.external_lvl = _LOGGER.getEffectiveLevel() _LOGGER.setLevel(logging.NOTSET) - if ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging self.ts_level = ts_logging.get_reportable_log_level() ts_logging.set_reportable_log_level(ts_logging.Level.Graph) - elif ENABLED_FEATURES.torch_tensorrt_runtime: + elif has_torch_tensorrt_runtime(): self.rt_level = torch.ops.tensorrt.get_logging_level() torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.VERBOSE) + 1) def __exit__(self, exc_type: Any, exc_value: 
Any, exc_tb: Any) -> None: _LOGGER.setLevel(self.external_lvl) - if ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging ts_logging.set_reportable_log_level(self.ts_level) - elif ENABLED_FEATURES.torch_tensorrt_runtime: + elif has_torch_tensorrt_runtime(): torch.ops.tensorrt.set_logging_level(self.rt_level) diff --git a/tests/py/dynamo/distributed/test_nccl_ops.py b/tests/py/dynamo/distributed/test_nccl_ops.py index eafe16d455..c2fff2027e 100644 --- a/tests/py/dynamo/distributed/test_nccl_ops.py +++ b/tests/py/dynamo/distributed/test_nccl_ops.py @@ -8,7 +8,25 @@ from distributed_utils import set_environment_variables_pytest from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt._utils import is_platform_supported_for_trtllm +from torch_tensorrt._features import ENABLED_FEATURES + + +def is_distributed_nccl_available(): + """ + Check if torch.distributed with NCCL backend is available. + + Note: torch.distributed is available on Windows but NCCL backend is not. + NCCL (NVIDIA Collective Communications Library) is Linux/Unix only. + This function returns False on Windows, Jetson, and other platforms + where NCCL backend is not supported. + """ + try: + import torch.distributed as dist + + # Check if NCCL backend is available (False on Windows, since its gloo. For ORIN some torch distribution it is available + return dist.is_nccl_available() + except (ImportError, AttributeError): + return False class DistributedGatherModel(nn.Module): @@ -42,9 +60,15 @@ def forward(self, x): class TestNcclOpsConverter(DispatchTestCase): + # 1. Skip if NCCL backend is not available (e.g., Windows, Jetson) - hard requirement + # 2. 
Skip if TRTLLM is unavailable (e.g., CUDA 13) - no converters registered + @unittest.skipIf( + not is_distributed_nccl_available(), + "Skipped: NCCL backend is not available (Windows/Jetson Orin not supported).", + ) @unittest.skipIf( - not is_platform_supported_for_trtllm(), - "Skipped on Windows, Jetson and CUDA13: NCCL backend is not supported.", + not ENABLED_FEATURES.trtllm_for_nccl, + "Skipped: TensorRT-LLM plugin for NCCL is not available (e.g., CUDA 13).", ) @classmethod def setUpClass(cls):