
Commit f9bc2ca

Revert "Fix process_weights_after_loading for fp8 dense"
This reverts commit 41abdf1.
1 parent: 84a98a2

File tree

2 files changed: +7 -19 lines changed

nemo_rl/algorithms/grpo.py
nemo_rl/models/generation/fp8.py

nemo_rl/algorithms/grpo.py

Lines changed: 0 additions & 1 deletion
@@ -1321,7 +1321,6 @@ def grpo_train(
         print("\n📊 Training Results:")
 
         print(f" • Loss: {metrics['loss']:.4f}")
-        print(f" • Token Mult Prob Error: {metrics['token_mult_prob_error']:.4f}")
         if master_config["grpo"]["use_dynamic_sampling"]:
             print(f" • Avg Filtered Reward: {np.mean(rewards.numpy()):.4f}")
             print(
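The deleted print reported `token_mult_prob_error`, a trainer-vs-inference log-prob consistency check. The sketch below shows one plausible way such a metric can be computed; the formula and helper name here are assumptions for illustration, not the definition used by nemo_rl:

import torch

def token_mult_prob_error(trainer_logprobs: torch.Tensor,
                          inference_logprobs: torch.Tensor,
                          mask: torch.Tensor) -> float:
    # Hypothetical sketch: exp(|logp_trainer - logp_inference|) is the
    # multiplicative factor by which the two engines disagree on a token's
    # probability; 1.0 means perfect agreement. nemo_rl's exact formula
    # may differ.
    per_token = torch.exp((trainer_logprobs - inference_logprobs).abs())
    per_token = torch.where(mask.bool(), per_token, torch.ones_like(per_token))
    return per_token.max().item()

err = token_mult_prob_error(
    torch.tensor([[-1.20, -0.70]]),   # log-probs from the trainer
    torch.tensor([[-1.25, -0.70]]),   # log-probs from the inference engine
    torch.ones(1, 2),                 # valid-token mask
)
print(f" • Token Mult Prob Error: {err:.4f}")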

nemo_rl/models/generation/fp8.py

Lines changed: 7 additions & 18 deletions
@@ -301,7 +301,7 @@ def load_weights(weights, model_runner):
             )
             param_scale = torch.squeeze(param_scale, dim=-1)
             weights_quantized.append([k, param_lp])
-            weights_quantized.append([k + "_scale", param_scale])
+            weights_quantized.append([k + "_scale_inv", param_scale])
     # Monkey patch the param class to their subclass, as certain models
     # will check the param type to call the proper weightloader
     for name, param in model.named_parameters():
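The renamed key matters because vLLM's blockwise FP8 path consumes a `weight_scale_inv` parameter (the DeepSeek-style dequantization scale) rather than `weight_scale`. As a rough illustration, here is a minimal sketch of blockwise quantization producing the pair of entries the loop above appends; the block size, helper name, and tiling are assumptions for illustration, not code from this repo:

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def quantize_blockwise(w: torch.Tensor, block: int = 128):
    # Assumed 128x128 tiling; view the matrix as a grid of tiles.
    out, n_in = w.shape
    tiles = w.reshape(out // block, block, n_in // block, block)
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scale = amax / FP8_MAX                       # one scale per tile
    w_fp8 = (tiles / scale).to(torch.float8_e4m3fn).reshape(out, n_in)
    # "scale_inv" follows the DeepSeek/vLLM naming convention: the tensor
    # you multiply the fp8 weight by to dequantize it.
    scale_inv = scale.squeeze(1).squeeze(-1)     # (out/block, in/block)
    return w_fp8, scale_inv

w = torch.randn(256, 256)
w_fp8, scale_inv = quantize_blockwise(w)
weights_quantized = [
    ["layers.0.mlp.w", w_fp8],
    ["layers.0.mlp.w" + "_scale_inv", scale_inv],
]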
@@ -391,10 +391,6 @@ def cast_tensor_to_fp8_blockwise(
 
 def process_weights_after_loading(self, layer) -> None:
     from torch.nn import Parameter
-    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-        maybe_post_process_fp8_weight_block,
-        process_fp8_weight_block_strategy,
-    )
     from vllm.model_executor.parameter import (
         BlockQuantScaleParameter,
         ModelWeightParameter,
@@ -420,34 +416,27 @@ def _create_param_from_subclass_attributes(custom_param):
         param.subclass_type = type(custom_param)
         return param
 
-    weight_scale = (
-        layer.weight_scale_inv
-        if hasattr(layer, "weight_scale_inv")
-        else layer.weight_scale
-    )
-    weight, weight_scale = process_fp8_weight_block_strategy(layer.weight, weight_scale)
+    weight = layer.weight.data
+    weight_scale_inv = layer.weight_scale_inv.data
+    weight = self._maybe_pad_weight(weight)
 
     layer.weight = _create_param_from_subclass_attributes(
         ModelWeightParameter(
-            data=weight.data,
+            data=weight,
             output_dim=0,
             input_dim=1,
             weight_loader=layer.weight.weight_loader,
         )
     )
-    layer.weight_scale = _create_param_from_subclass_attributes(
+    layer.weight_scale_inv = _create_param_from_subclass_attributes(
         BlockQuantScaleParameter(
-            data=weight_scale.data,
+            data=weight_scale_inv,
             output_dim=0,
             input_dim=1,
             weight_loader=layer.weight_scale_inv.weight_loader,
         )
     )
 
-    del layer.weight_scale_inv
-
-    maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)
-
 
 @triton.jit
 def _per_token_group_quant_fp8(
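For readers unfamiliar with the kernel named in the trailing context, `_per_token_group_quant_fp8` computes one FP8 scale per fixed-size group of each token's values. Below is a hedged pure-PyTorch reference of that idea, assuming a group size of 128 and e4m3 storage; the actual Triton kernel fuses this computation and its exact signature may differ:

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def per_token_group_quant_fp8(x: torch.Tensor, group_size: int = 128):
    tokens, hidden = x.shape
    groups = x.reshape(tokens, hidden // group_size, group_size)
    # One scale per (token, group): the group's amax mapped onto fp8 range.
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / FP8_MAX
    x_fp8 = (groups / scale).to(torch.float8_e4m3fn).reshape(tokens, hidden)
    return x_fp8, scale.squeeze(-1)  # scales: (tokens, hidden // group_size)

x = torch.randn(4, 256)
x_fp8, scales = per_token_group_quant_fp8(x)
# Dequantize to check that the round-trip error stays small.
x_back = (x_fp8.to(torch.float32).reshape(4, 2, 128) * scales.unsqueeze(-1)).reshape(4, 256)
print((x - x_back).abs().max())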
