Merge branch 'akoumparouli/add_repr_for_parallel_linear' into 'main'
Add repr for parallel linear layers

See merge request ADLR/megatron-lm!2358
ericharper committed Jan 26, 2025
2 parents f960d4d + 0a43540 commit d57d110
Showing 2 changed files with 34 additions and 0 deletions.
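
For context, an illustration that is not part of the commit: with these __repr__ methods, printing a parallel linear module yields a compact one-line summary instead of the bare class name that torch.nn.Module prints by default. The layer name, feature sizes, and TP degree below are assumed values for a hypothetical 2-way tensor-parallel setup.

    >>> print(layer)  # hypothetical column-parallel layer, TP=2
    TEColumnParallelLinear(in_features=4096, out_features=16384, bias=True, TP=2)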
18 changes: 18 additions & 0 deletions megatron/core/extensions/transformer_engine.py
@@ -386,6 +386,12 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
            state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
        )

    def __repr__(self):
        return (
            f"{type(self).__name__}(in_features={self.in_features}, "
            f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})"
        )


class TEColumnParallelLinear(TELinear):
"""
Expand Down Expand Up @@ -464,6 +470,12 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
)

def __repr__(self):
return (
f"{type(self).__name__}(in_features={self.in_features}, "
f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})"
)


class TERowParallelLinear(TELinear):
"""
Expand Down Expand Up @@ -542,6 +554,12 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
state_dict, prefix, {'weight': 1}, sharded_offsets
)

def __repr__(self):
return (
f"{type(self).__name__}(in_features={self.in_features}, "
f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})"
)


class TEDotProductAttention(te.pytorch.DotProductAttention):
"""
16 changes: 16 additions & 0 deletions megatron/core/tensor_parallel/layers.py
@@ -985,6 +985,14 @@ def get_extra_state(self) -> None:
        """Keep compatibility with TE state dict."""
        return None

    def __repr__(self):
        tp = self.output_size // self.output_size_per_partition
        use_bias = self.bias is not None  # bias is a Parameter or None, so a presence check suffices
        return (
            f"{type(self).__name__}(in_features={self.input_size}, "
            f"out_features={self.output_size}, bias={use_bias}, TP={tp})"
        )


class RowParallelLinear(torch.nn.Module):
"""Linear layer with row parallelism.
Expand Down Expand Up @@ -1206,3 +1214,11 @@ def set_extra_state(self, state: Any):
def get_extra_state(self) -> None:
"""Keep compatibility with TE state dict."""
return None

def __repr__(self):
tp = self.input_size // self.input_size_per_partition
use_bias = self.bias is not None and self.bias is True
return (
f"{type(self).__name__}(in_features={self.input_size}, "
f"out_features={self.output_size}, bias={use_bias}, TP={tp})"
)
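
As a sanity check, here is a minimal, self-contained sketch of the repr logic added to layers.py; the class name and sizes are hypothetical and not part of the commit. The TP degree is recovered from the weight partitioning, and bias presence from whether the parameter exists.

    import torch

    class FakeColumnParallelLinear(torch.nn.Module):
        """Stand-in for Megatron's ColumnParallelLinear, for illustration only."""

        def __init__(self, input_size, output_size, tp_size, bias=True):
            super().__init__()
            self.input_size = input_size
            self.output_size = output_size
            # Each tensor-parallel rank holds output_size / tp_size rows of the weight.
            self.output_size_per_partition = output_size // tp_size
            if bias:
                self.bias = torch.nn.Parameter(torch.zeros(self.output_size_per_partition))
            else:
                self.register_parameter('bias', None)

        def __repr__(self):
            # Recover the TP degree from the partitioning, as the commit does.
            tp = self.output_size // self.output_size_per_partition
            use_bias = self.bias is not None  # bias is a Parameter or None
            return (
                f"{type(self).__name__}(in_features={self.input_size}, "
                f"out_features={self.output_size}, bias={use_bias}, TP={tp})"
            )

    print(FakeColumnParallelLinear(1024, 4096, tp_size=4))
    # FakeColumnParallelLinear(in_features=1024, out_features=4096, bias=True, TP=4)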
