From 0a4354003528f584840bb8552be7204843c4ed29 Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Sat, 25 Jan 2025 17:22:06 -0800
Subject: [PATCH] ADLR/megatron-lm!2358 - Add repr for parallel linear layers

---
 megatron/core/extensions/transformer_engine.py | 18 ++++++++++++++++++
 megatron/core/tensor_parallel/layers.py        | 16 ++++++++++++++++
 2 files changed, 34 insertions(+)

diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py
index 9b7ecf3ffd..f170e68c4e 100644
--- a/megatron/core/extensions/transformer_engine.py
+++ b/megatron/core/extensions/transformer_engine.py
@@ -377,6 +377,12 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
             state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
         )
 
+    def __repr__(self):
+        return (
+            f"{type(self).__name__}(in_features={self.in_features}, "
+            f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})"
+        )
+
 
 class TEColumnParallelLinear(TELinear):
     """
@@ -451,6 +457,12 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
             state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
         )
 
+    def __repr__(self):
+        return (
+            f"{type(self).__name__}(in_features={self.in_features}, "
+            f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})"
+        )
+
 
 class TERowParallelLinear(TELinear):
     """
@@ -525,6 +537,12 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
             state_dict, prefix, {'weight': 1}, sharded_offsets
         )
 
+    def __repr__(self):
+        return (
+            f"{type(self).__name__}(in_features={self.in_features}, "
+            f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})"
+        )
+
 
 class TEDotProductAttention(te.pytorch.DotProductAttention):
     """
diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py
index 12d2be69a9..a792ef9bea 100644
--- a/megatron/core/tensor_parallel/layers.py
+++ b/megatron/core/tensor_parallel/layers.py
@@ -989,6 +989,14 @@ def get_extra_state(self) -> None:
         """Keep compatibility with TE state dict."""
         return None
 
+    def __repr__(self):
+        tp = self.output_size // self.output_size_per_partition
+        use_bias = self.bias is not None
+        return (
+            f"{type(self).__name__}(in_features={self.input_size}, "
+            f"out_features={self.output_size}, bias={use_bias}, TP={tp})"
+        )
+
 
 class RowParallelLinear(torch.nn.Module):
     """Linear layer with row parallelism.
@@ -1213,3 +1221,11 @@ def set_extra_state(self, state: Any):
     def get_extra_state(self) -> None:
         """Keep compatibility with TE state dict."""
         return None
+
+    def __repr__(self):
+        tp = self.input_size // self.input_size_per_partition
+        use_bias = self.bias is not None
+        return (
+            f"{type(self).__name__}(in_features={self.input_size}, "
+            f"out_features={self.output_size}, bias={use_bias}, TP={tp})"
+        )
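
Usage note (illustrative, not part of the patch): a minimal standalone sketch
of the repr pattern these hunks add, using a hypothetical toy class rather
than the Megatron layers. As in the ColumnParallelLinear hunk, the layer
reports its global in/out feature counts plus the tensor-parallel (TP) degree
recovered from the size of the local partition; only `torch` is assumed.

    import torch

    class ToyColumnParallelLinear(torch.nn.Module):
        """Toy stand-in (not the Megatron class): holds one TP shard of the weight."""

        def __init__(self, input_size, output_size, tp_size, bias=True):
            super().__init__()
            self.input_size = input_size    # global in_features
            self.output_size = output_size  # global out_features
            self.output_size_per_partition = output_size // tp_size
            # Each rank holds only its column shard of the weight/bias.
            self.weight = torch.nn.Parameter(
                torch.empty(self.output_size_per_partition, input_size)
            )
            self.bias = (
                torch.nn.Parameter(torch.empty(self.output_size_per_partition))
                if bias
                else None
            )

        def __repr__(self):
            # Recover the TP degree from global vs. per-partition output size,
            # mirroring the ColumnParallelLinear.__repr__ added by the patch.
            tp = self.output_size // self.output_size_per_partition
            return (
                f"{type(self).__name__}(in_features={self.input_size}, "
                f"out_features={self.output_size}, bias={self.bias is not None}, TP={tp})"
            )

    # print() invokes __repr__, so nested modules now show shape/bias/TP info:
    print(ToyColumnParallelLinear(1024, 4096, tp_size=2))
    # ToyColumnParallelLinear(in_features=1024, out_features=4096, bias=True, TP=2)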