43 changes: 38 additions & 5 deletions composer/algorithms/gradient_clipping/gradient_clipping.py
@@ -25,7 +25,7 @@ def apply_gradient_clipping(
    clipping_type: str,
    clipping_threshold: float,
    fsdp_enabled: bool,
):
) -> Union[torch.Tensor, None]:
    """Clips all gradients in model based on specified clipping_type.

    Args:
@@ -41,12 +41,16 @@ def apply_gradient_clipping(
            threshold by which if grad_norm / weight_norm is greater than this threshold then
            scale gradients by this threshold * (weight_norm / grad_norm) (for 'adaptive').
        fsdp_enabled (bool): Whether the model is an FSDP model.

    Returns:
        Union[torch.Tensor, None]: The total gradient norm before clipping for 'norm' clipping type,
Contributor: Let's just always return the result, not just for 'norm'. It's a weird contract for a separate function downstream to know this behavior.

Contributor: The downstream function always guards by checking if the clipping_type is 'norm', so that should be good enough.

Contributor (Author): I thought about it, but the other two options don't have top-level scalar values that can be returned, hence why I stuck with returning None for those. I agree that the contract is awkward, but we needed to propagate the norm to the logger.

Contributor: oh lol, why do they modify in place for 'value' and not for 'norm' 😆 😓

Contributor: Approved. Not ideal, but it makes sense that this is the best we can do.
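
As an illustration of the contract settled on in this thread (not code from the PR): the helper returns the pre-clipping gradient norm only for the 'norm' type and None otherwise, and the downstream caller guards on clipping_type before touching the return value. The `_sketch` function names below are hypothetical; only the `torch.nn.utils` calls match the diff.

```python
from typing import Optional, Union

import torch


def apply_gradient_clipping_sketch(parameters, clipping_type: str,
                                   clipping_threshold: float) -> Union[torch.Tensor, None]:
    """Minimal stand-in for the helper in this diff (FSDP and 'adaptive' paths omitted)."""
    if clipping_type == 'norm':
        # clip_grad_norm_ clips in place and returns the total norm *before* clipping.
        return torch.nn.utils.clip_grad_norm_(parameters, max_norm=clipping_threshold)
    elif clipping_type == 'value':
        # clip_grad_value_ only modifies gradients in place; there is no scalar to return.
        torch.nn.utils.clip_grad_value_(parameters, clip_value=clipping_threshold)
    return None


def downstream_caller_sketch(model: torch.nn.Module, clipping_type: str,
                             clipping_threshold: float) -> Optional[float]:
    """Shows how a caller guards on clipping_type before using the returned norm."""
    maybe_norm = apply_gradient_clipping_sketch(model.parameters(), clipping_type, clipping_threshold)
    if clipping_type == 'norm':
        assert maybe_norm is not None  # contract: 'norm' always yields a tensor
        return maybe_norm.item()
    return None
```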

            None otherwise.
    """
    if fsdp_enabled:
        for module in model.modules():
            if isinstance(module, FullyShardedDataParallel) and module.check_is_root():
                if clipping_type == 'norm':
                    module.clip_grad_norm_(max_norm=clipping_threshold)
                    return module.clip_grad_norm_(max_norm=clipping_threshold)
                elif clipping_type == 'value':
                    module.clip_grad_norm_(max_norm=clipping_threshold, norm_type=float('inf'))
                else:
@@ -56,12 +60,14 @@ def apply_gradient_clipping(
    if clipping_type == 'adaptive':
        _apply_agc(parameters, clipping_threshold=clipping_threshold)
    elif clipping_type == 'norm':
        torch.nn.utils.clip_grad_norm_(parameters, max_norm=clipping_threshold)
        return torch.nn.utils.clip_grad_norm_(parameters, max_norm=clipping_threshold)
    elif clipping_type == 'value':
        torch.nn.utils.clip_grad_value_(parameters, clip_value=clipping_threshold)
    else:
        raise ValueError(f"clipping_type must be 'adaptive', 'norm', or 'value' not {clipping_type} ")

    return None


def _apply_agc(
    parameters: Union[torch.Tensor, Iterable[torch.Tensor]],
@@ -122,24 +128,51 @@ class GradientClipping(Algorithm):
            to (for 'value'), what values to clip the gradient norms to (for 'norm'), and
            threshold by which if grad_norm / weight_norm is greater than this threshold then
            scale gradients by this threshold * (weight_norm / grad_norm) (for 'adaptive').
        clipping_frequency_window (int, optional): Number of steps to use for calculating
            the rolling average of clipping frequency. Only used for 'norm' clipping type.
            Defaults to 100.
    """

    def __init__(self, clipping_type: str, clipping_threshold: float):
    def __init__(self, clipping_type: str, clipping_threshold: float, clipping_frequency_window: int = 100):
        self.clipping_type = clipping_type
        self.clipping_threshold = clipping_threshold
        self.clipping_frequency_window = clipping_frequency_window
        self._clipping_history = []

    def match(self, event: Event, state: State) -> bool:
        return event in [Event.INIT, Event.AFTER_TRAIN_BATCH]

    def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
        if event == Event.AFTER_TRAIN_BATCH:
            apply_gradient_clipping(
            maybe_grad_norm = apply_gradient_clipping(
                model=state.model,
                clipping_type=self.clipping_type,
                clipping_threshold=self.clipping_threshold,
                fsdp_enabled=state.fsdp_config_version == 1,
            )

            if self.clipping_type == 'norm':
                if maybe_grad_norm is None:
                    raise RuntimeError("Expected gradient norm to be returned for 'norm' clipping type, but got None")

                grad_norm = maybe_grad_norm.item()

                # Log the gradient norm before clipping
                logger.log_metrics({'gradient_norm/unclipped_magnitude': grad_norm})

                # Log whether clipping was applied
                clipping_applied = grad_norm > self.clipping_threshold
                logger.log_metrics({'gradient_norm/clipped': float(clipping_applied)})

                # Track clipping frequency
                self._clipping_history.append(float(clipping_applied))
                # Keep only last N steps for frequency calculation
                if len(self._clipping_history) > self.clipping_frequency_window:
                    self._clipping_history.pop(0)

                clipping_frequency = sum(self._clipping_history) / len(self._clipping_history)
                logger.log_metrics({'gradient_norm/clipping_frequency': clipping_frequency})
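
The rolling clipping-frequency bookkeeping above keeps the last clipping_frequency_window indicators in a plain list and drops the oldest entry with pop(0). Purely as an illustration of the same computation (not part of the PR), a collections.deque with maxlen gives an identical rolling average without the explicit pop:

```python
from collections import deque

# Hypothetical standalone version of the window logic above; the window size
# plays the same role as clipping_frequency_window.
clipping_window: deque = deque(maxlen=100)


def record_clipping_step(grad_norm: float, clipping_threshold: float) -> float:
    """Record whether this step was clipped and return the rolling clipping frequency."""
    clipping_window.append(float(grad_norm > clipping_threshold))
    # deque(maxlen=...) discards the oldest entry automatically, so no explicit pop is needed.
    return sum(clipping_window) / len(clipping_window)
```

In the PR itself the window size comes from the new clipping_frequency_window constructor argument, which defaults to 100.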


def _get_clipped_gradient_coeff(weights: torch.Tensor, grad: torch.Tensor, clipping_threshold: float = 0.01):
    """Clips all gradients in model based on ratio of gradient norms to parameter norms.
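
Since the body of _get_clipped_gradient_coeff is collapsed in this diff, here is a hedged sketch of the adaptive clipping coefficient the docstrings describe: when grad_norm / weight_norm exceeds the threshold, gradients are scaled by threshold * (weight_norm / grad_norm). The norm computation and the small epsilon are assumptions for illustration only; the actual implementation may compute norms unit-wise rather than over the whole tensor.

```python
import torch


def clipped_gradient_coeff_sketch(weights: torch.Tensor, grad: torch.Tensor,
                                  clipping_threshold: float = 0.01) -> torch.Tensor:
    """Illustrative coefficient for adaptive gradient clipping, per the docstring formula."""
    weight_norm = torch.linalg.vector_norm(weights)
    grad_norm = torch.linalg.vector_norm(grad)
    eps = 1e-6  # assumption: guards against division by zero, not taken from the diff
    ratio = grad_norm / (weight_norm + eps)
    # Scale by clipping_threshold * (weight_norm / grad_norm) only when the ratio
    # exceeds the threshold; otherwise leave gradients untouched (coefficient 1).
    return torch.where(ratio > clipping_threshold,
                       clipping_threshold * weight_norm / (grad_norm + eps),
                       torch.ones_like(grad_norm))
```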