diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py
index e74f47803e..f5e5ca5f5c 100644
--- a/megatron/core/extensions/transformer_engine.py
+++ b/megatron/core/extensions/transformer_engine.py
@@ -386,6 +386,12 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
             state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
         )
 
+    def __repr__(self):
+        return (
+            f"{type(self).__name__}(in_features={self.in_features}, "
+            f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})"
+        )
+
 
 class TEColumnParallelLinear(TELinear):
     """
@@ -464,6 +470,12 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
             state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
         )
 
+    def __repr__(self):
+        return (
+            f"{type(self).__name__}(in_features={self.in_features}, "
+            f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})"
+        )
+
 
 class TERowParallelLinear(TELinear):
     """
@@ -542,6 +554,12 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
             state_dict, prefix, {'weight': 1}, sharded_offsets
         )
 
+    def __repr__(self):
+        return (
+            f"{type(self).__name__}(in_features={self.in_features}, "
+            f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})"
+        )
+
 
 class TEDotProductAttention(te.pytorch.DotProductAttention):
     """
diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py
index fde8c106f1..9555041865 100644
--- a/megatron/core/tensor_parallel/layers.py
+++ b/megatron/core/tensor_parallel/layers.py
@@ -985,6 +985,14 @@ def get_extra_state(self) -> None:
         """Keep compatibility with TE state dict."""
         return None
 
+    def __repr__(self):
+        tp = self.output_size // self.output_size_per_partition
+        use_bias = self.bias is not None
+        return (
+            f"{type(self).__name__}(in_features={self.input_size}, "
+            f"out_features={self.output_size}, bias={use_bias}, TP={tp})"
+        )
+
 
 class RowParallelLinear(torch.nn.Module):
     """Linear layer with row parallelism.
@@ -1206,3 +1214,11 @@ def set_extra_state(self, state: Any):
     def get_extra_state(self) -> None:
         """Keep compatibility with TE state dict."""
         return None
+
+    def __repr__(self):
+        tp = self.input_size // self.input_size_per_partition
+        use_bias = self.bias is not None
+        return (
+            f"{type(self).__name__}(in_features={self.input_size}, "
+            f"out_features={self.output_size}, bias={use_bias}, TP={tp})"
+        )
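
A minimal usage sketch of what these reprs print, assuming an already-initialized Megatron model running under torch.distributed with tensor parallelism; the module path, feature sizes, and TP degree below are illustrative assumptions, not values taken from the patch:

    # Hypothetical interactive session after applying this patch.
    # `model` is assumed to be a GPT model built by Megatron with TP=4;
    # `decoder.layers[0].mlp.linear_fc1` is an assumed module path to a
    # ColumnParallelLinear (or TEColumnParallelLinear) instance.
    layer = model.decoder.layers[0].mlp.linear_fc1
    print(repr(layer))
    # Illustrative output:
    # ColumnParallelLinear(in_features=4096, out_features=16384, bias=True, TP=4)

With these reprs in place, `print(model)` shows the full (unsharded) in/out feature sizes plus the TP degree for each parallel linear layer, instead of PyTorch's default module repr.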