Commit
2024-12-06 nightly release (f450c59)
pytorchbot committed Dec 6, 2024
1 parent 4b96c2e commit a94f6e2
Showing 9 changed files with 220 additions and 84 deletions.
30 changes: 10 additions & 20 deletions torchrec/distributed/embedding_lookup.py
@@ -358,9 +358,6 @@ def purge(self) -> None:


class CommOpGradientScaling(torch.autograd.Function):
# user override: inline autograd.Function is safe to trace since only tensor mutations / no global state
_compiled_autograd_should_lift = False

@staticmethod
# pyre-ignore
def forward(
@@ -501,23 +498,16 @@ def _need_prefetch(config: GroupedEmbeddingConfig) -> bool:
"If you don't turn on prefetch_pipeline, cache locations might be wrong in backward and can cause wrong results.\n"
)
if hasattr(emb_op.emb_module, "prefetch"):
if isinstance(emb_op.emb_module, SSDTableBatchedEmbeddingBags):
emb_op.emb_module.prefetch(
indices=features.values(),
offsets=features.offsets(),
forward_stream=forward_stream,
)
else:
emb_op.emb_module.prefetch(
indices=features.values(),
offsets=features.offsets(),
forward_stream=forward_stream,
batch_size_per_feature_per_rank=(
features.stride_per_key_per_rank()
if features.variable_stride_per_key()
else None
),
)
emb_op.emb_module.prefetch(
indices=features.values(),
offsets=features.offsets(),
forward_stream=forward_stream,
batch_size_per_feature_per_rank=(
features.stride_per_key_per_rank()
if features.variable_stride_per_key()
else None
),
)

def _merge_variable_batch_embeddings(
self, embeddings: List[torch.Tensor], splits: List[List[int]]
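The change above removes the SSDTableBatchedEmbeddingBags special case, so every emb_module that exposes prefetch now receives batch_size_per_feature_per_rank when the input KJT uses variable strides per key. A minimal sketch of how that argument is derived from the features KJT; the stride_per_key_per_rank constructor argument used to build the example input is an assumption for illustration, not part of this commit:

```python
# Illustrative only: derive the prefetch kwarg exactly as the unified call does.
# Assumption: KeyedJaggedTensor can be built with stride_per_key_per_rank to get
# a variable-batch input; the two method calls below appear in the diff itself.
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5]),
    lengths=torch.tensor([1, 2, 1, 1]),
    stride_per_key_per_rank=[[2], [2]],  # per-key list of per-rank batch sizes
)

batch_size_per_feature_per_rank = (
    features.stride_per_key_per_rank()  # [[2], [2]] for this input
    if features.variable_stride_per_key()
    else None  # fixed-batch KJTs keep the previous behavior
)
```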
20 changes: 19 additions & 1 deletion torchrec/distributed/fused_params.py
@@ -7,7 +7,7 @@

# pyre-strict

from typing import Any, Dict, Iterable, Optional
from typing import Any, Dict, Iterable, List, Optional

import torch

@@ -24,6 +24,10 @@
FUSED_PARAM_TBE_ROW_ALIGNMENT: str = "__register_tbe_row_alignment"
FUSED_PARAM_BOUNDS_CHECK_MODE: str = "__register_tbe_bounds_check_mode"

# Force lengths to offsets conversion before TBE lookup. Helps with performance
# with certain ways to split models.
FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: str = "__register_lengths_to_offsets_lookup"


class TBEToRegisterMixIn:
def get_tbes_to_register(
@@ -68,6 +72,18 @@ def fused_param_bounds_check_mode(
return fused_params[FUSED_PARAM_BOUNDS_CHECK_MODE]


def fused_param_lengths_to_offsets_lookup(
fused_params: Optional[Dict[str, Any]]
) -> bool:
if (
fused_params is None
or FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP not in fused_params
):
return False
else:
return fused_params[FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP]


def is_fused_param_quant_state_dict_split_scale_bias(
fused_params: Optional[Dict[str, Any]]
) -> bool:
@@ -93,5 +109,7 @@ def tbe_fused_params(
fused_params_for_tbe.pop(FUSED_PARAM_TBE_ROW_ALIGNMENT)
if FUSED_PARAM_BOUNDS_CHECK_MODE in fused_params_for_tbe:
fused_params_for_tbe.pop(FUSED_PARAM_BOUNDS_CHECK_MODE)
if FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP in fused_params_for_tbe:
fused_params_for_tbe.pop(FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP)

return fused_params_for_tbe
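A minimal sketch of the new flag's lifecycle using only the helpers defined in this file: callers set FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP in fused_params, fused_param_lengths_to_offsets_lookup reads it (defaulting to False), and tbe_fused_params strips it before the remaining kwargs are forwarded to the TBE constructor:

```python
# Sketch of how the torchrec-internal flag is consumed and then filtered out.
from torchrec.distributed.fused_params import (
    FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP,
    fused_param_lengths_to_offsets_lookup,
    tbe_fused_params,
)

fused_params = {FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: True}

# The quant kernel consults this helper to choose the lengths-based TBE class.
assert fused_param_lengths_to_offsets_lookup(fused_params) is True
assert fused_param_lengths_to_offsets_lookup(None) is False

# Internal "__register_*" keys never reach the TBE constructor kwargs,
# mirroring the "**(tbe_fused_params(fused_params) or {})" call site.
tbe_kwargs = tbe_fused_params(fused_params) or {}
assert FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP not in tbe_kwargs
```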
139 changes: 92 additions & 47 deletions torchrec/distributed/quant_embedding_kernel.py
@@ -33,6 +33,7 @@
)
from torchrec.distributed.fused_params import (
fused_param_bounds_check_mode,
fused_param_lengths_to_offsets_lookup,
is_fused_param_quant_state_dict_split_scale_bias,
is_fused_param_register_tbe,
tbe_fused_params,
@@ -171,6 +172,19 @@ def _unwrap_kjt_for_cpu(
return indices, offsets, None


@torch.fx.wrap
def _unwrap_kjt_lengths(
features: KeyedJaggedTensor,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
indices = features.values()
lengths = features.lengths()
return (
indices.int(),
lengths.int(),
features.weights_or_none(),
)


@torch.fx.wrap
def _unwrap_optional_tensor(
tensor: Optional[torch.Tensor],
@@ -180,6 +194,26 @@ def _unwrap_optional_tensor(
return tensor


class IntNBitTableBatchedEmbeddingBagsCodegenWithLength(
IntNBitTableBatchedEmbeddingBagsCodegen
):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

# pyre-ignore Inconsistent override [14]
def forward(
self,
indices: torch.Tensor,
lengths: torch.Tensor,
per_sample_weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self._forward_impl(
indices=indices,
offsets=(torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)),
per_sample_weights=per_sample_weights,
)


class QuantBatchedEmbeddingBag(
BaseBatchedEmbeddingBag[
Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
@@ -237,22 +271,27 @@ def __init__(
)
)

self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = (
IntNBitTableBatchedEmbeddingBagsCodegen(
embedding_specs=embedding_specs,
device=device,
pooling_mode=self._pooling,
feature_table_map=self._feature_table_map,
row_alignment=self._tbe_row_alignment,
uvm_host_mapped=True, # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
bounds_check_mode=(
bounds_check_mode if bounds_check_mode else BoundsCheckMode.WARNING
),
feature_names_per_table=[
table.feature_names for table in config.embedding_tables
],
**(tbe_fused_params(fused_params) or {}),
)
self.lengths_to_tbe: bool = fused_param_lengths_to_offsets_lookup(fused_params)

if self.lengths_to_tbe:
tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegenWithLength
else:
tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegen

self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = tbe_clazz(
embedding_specs=embedding_specs,
device=device,
pooling_mode=self._pooling,
feature_table_map=self._feature_table_map,
row_alignment=self._tbe_row_alignment,
uvm_host_mapped=True, # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
bounds_check_mode=(
bounds_check_mode if bounds_check_mode else BoundsCheckMode.WARNING
),
feature_names_per_table=[
table.feature_names for table in config.embedding_tables
],
**(tbe_fused_params(fused_params) or {}),
)
if device is not None:
self._emb_module.initialize_weights()
@@ -271,44 +310,50 @@ def get_tbes_to_register(
) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]:
return {self._emb_module: self._config}

def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
# Important: _unwrap_kjt regex for FX tracing TAGing
if self._runtime_device.type == "cpu":
indices, offsets, per_sample_weights = _unwrap_kjt_for_cpu(
features, self._config.is_weighted
)
def _emb_module_forward(
self,
indices: torch.Tensor,
lengths_or_offsets: torch.Tensor,
weights: Optional[torch.Tensor],
) -> torch.Tensor:
kwargs = {"indices": indices}

if self.lengths_to_tbe:
kwargs["lengths"] = lengths_or_offsets
else:
indices, offsets, per_sample_weights = _unwrap_kjt(features)
kwargs["offsets"] = lengths_or_offsets

if self._is_weighted:
weights = _unwrap_optional_tensor(per_sample_weights)
if self._emb_module_registered:
# Conditional call of .forward function for FX:
# emb_module() can go through FX only if emb_module is registered in named_modules (FX node call_module)
# emb_module.forward() does not require registering emb_module in named_modules (FX node call_function)
# For some post processing that requires TBE emb_module copied in fx.GraphModule we need to be call_module, as it will copies this module inside fx.GraphModule unchanged.
return self.emb_module(
indices=indices,
offsets=offsets,
per_sample_weights=weights,
)
kwargs["per_sample_weights"] = _unwrap_optional_tensor(weights)

if self._emb_module_registered:
# Conditional call of .forward function for FX:
# emb_module() can go through FX only if emb_module is registered in named_modules (FX node call_module)
# emb_module.forward() does not require registering emb_module in named_modules (FX node call_function)
# For some post processing that requires TBE emb_module copied in fx.GraphModule we need to be call_module, as it will copies this module inside fx.GraphModule unchanged.
return self._emb_module(**kwargs)
else:
return self._emb_module.forward(**kwargs)

def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
# Important: _unwrap_kjt regex for FX tracing TAGing
lengths, offsets = None, None
if self._runtime_device.type == "cpu":
if self.lengths_to_tbe:
indices, lengths, per_sample_weights = _unwrap_kjt_lengths(features)
else:
return self.emb_module.forward(
indices=indices,
offsets=offsets,
per_sample_weights=weights,
indices, offsets, per_sample_weights = _unwrap_kjt_for_cpu(
features, self._config.is_weighted
)
else:
if self._emb_module_registered:
return self.emb_module(
indices=indices,
offsets=offsets,
)
if self.lengths_to_tbe:
indices, lengths, per_sample_weights = _unwrap_kjt_lengths(features)
else:
return self.emb_module.forward(
indices=indices,
offsets=offsets,
)
indices, offsets, per_sample_weights = _unwrap_kjt(features)

return self._emb_module_forward(
indices, lengths if lengths is not None else offsets, per_sample_weights
)

def named_buffers(
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
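The new IntNBitTableBatchedEmbeddingBagsCodegenWithLength subclass accepts lengths and converts them to offsets inside the module via torch.ops.fbgemm.asynchronous_complete_cumsum, so the conversion stays within the TBE call rather than in the lookup wrapper. A short sketch of the conversion semantics, assuming fbgemm_gpu is installed and that the op returns the complete prefix sum with a leading zero; the checked values follow from that assumption and are not output captured from this commit:

```python
# Sketch: lengths-to-offsets conversion used by the WithLength forward above.
import torch
import fbgemm_gpu  # noqa: F401  # registers the torch.ops.fbgemm operators

lengths = torch.tensor([2, 0, 3], dtype=torch.int32)  # 3 bags: sizes 2, 0, 3
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)

# Complete cumsum keeps a leading 0 and a trailing total, i.e. the offsets
# layout the TBE forward expects: bag i spans values[offsets[i]:offsets[i+1]].
assert offsets.tolist() == [0, 2, 2, 5]
```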
39 changes: 26 additions & 13 deletions torchrec/distributed/tests/test_pt2.py
@@ -319,7 +319,10 @@ def _test_kjt_input_module(
# Need to set as size in order to run a proper forward
em_inputs[0][0] = kjt.values().size(0)
em_inputs[2][0] = kjt.weights().size(0)
eager_output = symint_wrapper(*em_inputs)

if not kjt.values().is_meta:
eager_output = symint_wrapper(*em_inputs)

pt2_ir = torch.export.export(
symint_wrapper, em_inputs, {}, strict=False
)
@@ -504,6 +507,28 @@ def forward(self, kjt: KeyedJaggedTensor):
test_pt2_ir_export=True,
)

def test_kjt_length_per_key_meta(self) -> None:
class M(torch.nn.Module):
def forward(self, kjt: KeyedJaggedTensor):
return kjt.length_per_key()

kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])
kjt = kjt.to("meta")

# calling forward on meta inputs once traced should error out
# as calculating length_per_key requires a .tolist() call of lengths
self.assertRaisesRegex(
RuntimeError,
r".*Tensor\.item\(\) cannot be called on meta tensors.*",
lambda: self._test_kjt_input_module(
M(),
kjt,
(),
test_aot_inductor=False,
test_pt2_ir_export=True,
),
)

def test_kjt_offset_per_key(self) -> None:
class M(torch.nn.Module):
def forward(self, kjt: KeyedJaggedTensor):
@@ -629,7 +654,6 @@ def test_sharded_quant_ebc_non_strict_export(self) -> None:
local_device="cpu", compute_device="cpu"
)
kjt = input_kjts[0]
kjt = kjt.to("meta")
sharded_model(kjt.values(), kjt.lengths())

from torch.export import _trace
@@ -652,9 +676,6 @@ def test_sharded_quant_ebc_non_strict_export(self) -> None:
for n in ep.graph_module.graph.nodes:
self.assertFalse("auto_functionalized" in str(n.name))

# TODO: Fix Unflatten
# torch.export.unflatten(ep)

# pyre-ignore
@unittest.skipIf(
torch.cuda.device_count() <= 1,
@@ -665,9 +686,6 @@ def test_sharded_quant_fpebc_non_strict_export(self) -> None:
local_device="cpu", compute_device="cpu", feature_processor=True
)
kjt = input_kjts[0]
kjt = kjt.to("meta")
# Move FP parameters
sharded_model.to("meta")

sharded_model(kjt.values(), kjt.lengths())

@@ -690,11 +708,6 @@ def test_sharded_quant_fpebc_non_strict_export(self) -> None:
for n in ep.graph_module.graph.nodes:
self.assertFalse("auto_functionalized" in str(n.name))

# The nn_module_stack for this model forms a skip connection that looks like:
# a -> a.b -> a.b.c -> a.d
# This is currently not supported by unflatten.
# torch.export.unflatten(ep)

def test_maybe_compute_kjt_to_jt_dict(self) -> None:
kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])
self._test_kjt_input_module(
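The new test_kjt_length_per_key_meta case pins down that length_per_key() cannot be computed for meta-device inputs, because it materializes lengths with .tolist(). A standalone sketch of the underlying limitation; the exact exception type and message for the raw call are an assumption (the test above matches a RuntimeError in the export path):

```python
# Why the traced module must fail on meta inputs: .tolist() needs concrete
# values, and meta tensors carry only shape/dtype metadata.
import torch

lengths = torch.tensor([1, 2, 1, 1]).to("meta")
try:
    lengths.tolist()
except (RuntimeError, NotImplementedError) as e:  # exact type is build-dependent
    print(type(e).__name__, e)
```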
2 changes: 2 additions & 0 deletions torchrec/distributed/types.py
@@ -650,6 +650,7 @@ class KeyValueParams:
stats_reporter_config: Optional[TBEStatsReporterConfig] = None
use_passed_in_path: bool = True
l2_cache_size: Optional[int] = None
enable_async_update: Optional[bool] = None

# Parameter Server (PS) Attributes
ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None
@@ -672,6 +673,7 @@ def __hash__(self) -> int:
self.gather_ssd_cache_stats,
self.stats_reporter_config,
self.l2_cache_size,
self.enable_async_update,
)
)

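KeyValueParams gains an enable_async_update field, and the field is folded into __hash__ so that two configs differing only in this flag stop hashing identically when used as dict or set keys. A small sketch, assuming all KeyValueParams fields carry defaults as the surrounding dataclass suggests:

```python
# Sketch: the new flag now participates in hashing alongside the other fields.
from torchrec.distributed.types import KeyValueParams

a = KeyValueParams(enable_async_update=True)
b = KeyValueParams(enable_async_update=False)

print(a == b)              # False: the configs differ
print(hash(a) == hash(b))  # expected False now that the flag is hashed
```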
2 changes: 2 additions & 0 deletions torchrec/inference/modules.py
@@ -26,6 +26,7 @@
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.fused_params import (
FUSED_PARAM_BOUNDS_CHECK_MODE,
FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP,
FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
FUSED_PARAM_REGISTER_TBE_BOOL,
)
@@ -82,6 +83,7 @@ def trim_torch_package_prefix_from_typename(typename: str) -> str:
FUSED_PARAM_REGISTER_TBE_BOOL: True,
FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS: True,
FUSED_PARAM_BOUNDS_CHECK_MODE: BoundsCheckMode.NONE,
FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: False,
}

DEFAULT_SHARDERS: List[ModuleSharder[torch.nn.Module]] = [
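The inference defaults keep FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP off, so the lengths-based lookup is strictly opt-in. A hedged sketch of turning it on for a quantized EBC sharder; the assumption here is that QuantEmbeddingBagCollectionSharder accepts a fused_params dict like other torchrec sharders, which is not shown in this diff:

```python
# Sketch: opting in to the lengths-based TBE lookup for inference sharding.
from torchrec.distributed.fused_params import (
    FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP,
    FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
    FUSED_PARAM_REGISTER_TBE_BOOL,
)
from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder

fused_params = {
    FUSED_PARAM_REGISTER_TBE_BOOL: True,
    FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS: True,
    FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: True,  # flip the new default
}

# Used in place of the corresponding DEFAULT_SHARDERS entry when building the
# sharded inference module, so QuantBatchedEmbeddingBag picks the lengths-based kernel.
sharder = QuantEmbeddingBagCollectionSharder(fused_params=fused_params)
```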