ADLR/megatron-lm!2358 - Add repr for parallel linear layers
akoumpa authored and ericharper committed Jan 26, 2025
1 parent ddd920f commit 0a43540
Showing 2 changed files with 34 additions and 0 deletions.
18 changes: 18 additions & 0 deletions megatron/core/extensions/transformer_engine.py
@@ -377,6 +377,12 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
             state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
         )
 
+    def __repr__(self):
+        return (
+            f"{type(self).__name__}(in_features={self.in_features}, "
+            f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})"
+        )
+
 
 class TEColumnParallelLinear(TELinear):
     """
@@ -451,6 +457,12 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
             state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
         )
 
+    def __repr__(self):
+        return (
+            f"{type(self).__name__}(in_features={self.in_features}, "
+            f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})"
+        )
+
 
 class TERowParallelLinear(TELinear):
     """
@@ -525,6 +537,12 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
             state_dict, prefix, {'weight': 1}, sharded_offsets
         )
 
+    def __repr__(self):
+        return (
+            f"{type(self).__name__}(in_features={self.in_features}, "
+            f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})"
+        )
+
 
 class TEDotProductAttention(te.pytorch.DotProductAttention):
     """
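For context, the __repr__ added to these TE wrappers makes a printed module tree show each linear layer's logical shape, bias flag, and tensor-parallel degree. Below is a minimal standalone sketch of the resulting format; the class name and sizes are illustrative, not Megatron code:

# Toy stand-in exposing the same attributes the TE wrappers use in their repr
# (in_features, out_features, use_bias, tp_size); sizes here are made up.
class DemoParallelLinear:
    def __init__(self, in_features, out_features, use_bias, tp_size):
        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = use_bias
        self.tp_size = tp_size

    def __repr__(self):
        # Same f-string format as the repr added in this commit.
        return (
            f"{type(self).__name__}(in_features={self.in_features}, "
            f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})"
        )

print(DemoParallelLinear(4096, 16384, True, 8))
# DemoParallelLinear(in_features=4096, out_features=16384, bias=True, TP=8)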
16 changes: 16 additions & 0 deletions megatron/core/tensor_parallel/layers.py
@@ -989,6 +989,14 @@ def get_extra_state(self) -> None:
         """Keep compatibility with TE state dict."""
         return None
 
+    def __repr__(self):
+        tp = self.output_size // self.output_size_per_partition
+        use_bias = self.bias is not None and self.bias is True
+        return (
+            f"{type(self).__name__}(in_features={self.input_size}, "
+            f"out_features={self.output_size}, bias={use_bias}, TP={tp})"
+        )
+
 
 class RowParallelLinear(torch.nn.Module):
     """Linear layer with row parallelism.
@@ -1213,3 +1221,11 @@ def set_extra_state(self, state: Any):
     def get_extra_state(self) -> None:
         """Keep compatibility with TE state dict."""
         return None
+
+    def __repr__(self):
+        tp = self.input_size // self.input_size_per_partition
+        use_bias = self.bias is not None and self.bias is True
+        return (
+            f"{type(self).__name__}(in_features={self.input_size}, "
+            f"out_features={self.output_size}, bias={use_bias}, TP={tp})"
+        )
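Unlike the TE wrappers, these tensor_parallel layers carry no explicit tp_size attribute, so the new __repr__ derives the tensor-parallel degree from the ratio of the full dimension to the per-partition dimension (output dimension for ColumnParallelLinear, input dimension for RowParallelLinear). A small sketch of that derivation with assumed sizes (illustrative numbers, not taken from the source):

# ColumnParallelLinear splits output_size across TP ranks, so the degree is the ratio
# of the full dimension to the slice each rank holds; sizes below assume an 8-way split.
output_size = 16384
output_size_per_partition = 2048
tp = output_size // output_size_per_partition
print(f"TP={tp}")  # TP=8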