2024-12-05 nightly release (8ca8b65)

pytorchbot committed Dec 5, 2024
1 parent fe97db9 commit 4b96c2e
Showing 20 changed files with 657 additions and 164 deletions.
1 change: 0 additions & 1 deletion examples/retrieval/knn_index.py
@@ -43,7 +43,6 @@ def get_index(
res = faiss.StandardGpuResources()
# pyre-fixme[16]
config = faiss.GpuIndexIVFPQConfig()
# pyre-ignore[16]
index = faiss.GpuIndexIVFPQ(
res,
embedding_dim,
1 change: 1 addition & 0 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -1166,6 +1166,7 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
(
SplitTableBatchedEmbeddingBagsCodegen,
DenseTableBatchedEmbeddingBagsCodegen,
SSDTableBatchedEmbeddingBags,
),
):
return self.emb_module(
8 changes: 0 additions & 8 deletions torchrec/distributed/embedding.py
@@ -96,14 +96,6 @@
except OSError:
pass

try:
from tensordict import TensorDict
except ImportError:

class TensorDict:
pass


logger: logging.Logger = logging.getLogger(__name__)


37 changes: 4 additions & 33 deletions torchrec/distributed/embeddingbag.py
@@ -65,10 +65,8 @@
QuantizedCommCodecs,
ShardedTensor,
ShardingEnv,
ShardingEnv2D,
ShardingType,
ShardMetadata,
TensorProperties,
)
from torchrec.distributed.utils import (
add_params_from_parameter_sharding,
@@ -99,13 +97,6 @@
except OSError:
pass

try:
from tensordict import TensorDict
except ImportError:

class TensorDict:
pass


def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
return (
@@ -943,31 +934,11 @@ def _initialize_torch_state(self) -> None:  # noqa
# created ShardedTensors once in init, use in post_state_dict_hook
# note: at this point kvstore backed tensors don't own valid snapshots, so no read
# access is allowed on them.
sharding_spec = none_throws(
self.module_sharding_plan[table_name].sharding_spec
)
metadata = sharding_spec.build_metadata(
tensor_sizes=self._name_to_table_size[table_name],
tensor_properties=(
TensorProperties(
dtype=local_shards[0].tensor.dtype,
layout=local_shards[0].tensor.layout,
requires_grad=local_shards[0].tensor.requires_grad,
)
if local_shards
else TensorProperties()
),
)

self._model_parallel_name_to_sharded_tensor[table_name] = (
ShardedTensor._init_from_local_shards_and_global_metadata(
local_shards=local_shards,
sharded_tensor_metadata=metadata,
process_group=(
self._env.sharding_pg
if isinstance(self._env, ShardingEnv2D)
else self._env.process_group
),
ShardedTensor._init_from_local_shards(
local_shards,
self._name_to_table_size[table_name],
process_group=self._env.process_group,
)
)
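
Note on the hunk above: the hand-built TensorProperties and sharding-spec metadata are replaced by ShardedTensor._init_from_local_shards, which derives dtype, layout, and requires_grad from the local shards and only needs the global table size. A minimal sketch of that (private) constructor outside TorchRec, assuming an initialized default process group and that every rank calls it; build_sharded_table and its arguments are illustrative, not part of the commit:

import torch
import torch.distributed as dist
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor


def build_sharded_table(
    local_rows: torch.Tensor,  # this rank's slice of the embedding table
    row_offset: int,           # first global row held by this rank
    global_rows: int,
    embedding_dim: int,
) -> ShardedTensor:
    # Describe where this rank's rows sit inside the global table.
    shard_md = ShardMetadata(
        shard_offsets=[row_offset, 0],
        shard_sizes=list(local_rows.shape),
        placement=f"rank:{dist.get_rank()}/{local_rows.device}",
    )
    local_shards = [Shard(tensor=local_rows, metadata=shard_md)]
    # dtype / layout / requires_grad are taken from the shard tensors, so no
    # TensorProperties or global metadata needs to be assembled by hand.
    return ShardedTensor._init_from_local_shards(
        local_shards,
        [global_rows, embedding_dim],
        process_group=dist.group.WORLD,
    )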

29 changes: 20 additions & 9 deletions torchrec/distributed/quant_embedding.py
@@ -44,7 +44,6 @@
InferCwSequenceEmbeddingSharding,
)
from torchrec.distributed.sharding.rw_sequence_sharding import (
InferCPURwSequenceEmbeddingSharding,
InferRwSequenceEmbeddingSharding,
)
from torchrec.distributed.sharding.sequence_sharding import InferSequenceShardingContext
@@ -113,31 +112,43 @@ def create_infer_embedding_sharding(
List[torch.Tensor],
List[torch.Tensor],
]:
device_type = get_device_from_sharding_infos(sharding_infos)
device_type_from_sharding_infos: str = get_device_from_sharding_infos(
sharding_infos
)

if device_type in ["cuda", "mtia"]:
if device_type_from_sharding_infos in ["cuda", "mtia"]:
if sharding_type == ShardingType.TABLE_WISE.value:
return InferTwSequenceEmbeddingSharding(sharding_infos, env, device)
elif sharding_type == ShardingType.COLUMN_WISE.value:
return InferCwSequenceEmbeddingSharding(sharding_infos, env, device)
elif sharding_type == ShardingType.ROW_WISE.value:
return InferRwSequenceEmbeddingSharding(sharding_infos, env, device)
return InferRwSequenceEmbeddingSharding(
sharding_infos=sharding_infos,
env=env,
device=device,
device_type_from_sharding_infos=device_type_from_sharding_infos,
)
else:
raise ValueError(
f"Sharding type not supported {sharding_type} for {device_type} sharding"
f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
)
elif device_type == "cpu":
elif device_type_from_sharding_infos == "cpu":
if sharding_type == ShardingType.ROW_WISE.value:
return InferCPURwSequenceEmbeddingSharding(sharding_infos, env, device)
return InferRwSequenceEmbeddingSharding(
sharding_infos=sharding_infos,
env=env,
device=device,
device_type_from_sharding_infos=device_type_from_sharding_infos,
)
elif sharding_type == ShardingType.TABLE_WISE.value:
return InferTwSequenceEmbeddingSharding(sharding_infos, env, device)
else:
raise ValueError(
f"Sharding type not supported {sharding_type} for {device_type} sharding"
f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
)
else:
raise ValueError(
f"Sharding type not supported {sharding_type} for {device_type} sharding"
f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
)
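
Note on the routing above: with InferCPURwSequenceEmbeddingSharding removed, CPU and CUDA/MTIA row-wise inference both construct InferRwSequenceEmbeddingSharding, and only the forwarded device-type string differs. A plain-Python reduction of the branch structure, using descriptive strings instead of the real sharding classes; pick_sequence_sharding is illustrative only:

# Illustrative reduction of create_infer_embedding_sharding's routing above;
# the returned strings stand in for the TorchRec sharding classes.
def pick_sequence_sharding(device_type: str, sharding_type: str) -> str:
    if device_type in ("cuda", "mtia"):
        if sharding_type == "table_wise":
            return "InferTwSequenceEmbeddingSharding"
        if sharding_type == "column_wise":
            return "InferCwSequenceEmbeddingSharding"
        if sharding_type == "row_wise":
            # The device type is forwarded so the row-wise output dist can
            # skip its all-to-one exchange when the tables live on CPU.
            return f"InferRwSequenceEmbeddingSharding(device_type={device_type!r})"
    elif device_type == "cpu":
        if sharding_type == "row_wise":
            return f"InferRwSequenceEmbeddingSharding(device_type={device_type!r})"
        if sharding_type == "table_wise":
            return "InferTwSequenceEmbeddingSharding"
    raise ValueError(
        f"Sharding type not supported {sharding_type} for {device_type} sharding"
    )


print(pick_sequence_sharding("cpu", "row_wise"))
print(pick_sequence_sharding("cuda", "column_wise"))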


103 changes: 13 additions & 90 deletions torchrec/distributed/sharding/rw_sequence_sharding.py
@@ -167,16 +167,26 @@ def __init__(
self,
device: torch.device,
world_size: int,
device_type_from_sharding_infos: Optional[str] = None,
) -> None:
super().__init__()
self._dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(device, world_size)
self._device_type_from_sharding_infos: Optional[str] = (
device_type_from_sharding_infos
)

def forward(
self,
local_embs: List[torch.Tensor],
sharding_ctx: Optional[InferSequenceShardingContext] = None,
) -> List[torch.Tensor]:
return self._dist(local_embs)
# for cpu sharder, output dist should be a no-op
return (
local_embs
if self._device_type_from_sharding_infos is not None
and self._device_type_from_sharding_infos == "cpu"
else self._dist(local_embs)
)


class InferRwSequenceEmbeddingSharding(
@@ -202,6 +212,7 @@ def create_input_dist(
(emb_sharding, is_even_sharding) = get_embedding_shard_metadata(
self._grouped_embedding_configs_per_rank
)

return InferRwSparseFeaturesDist(
world_size=self._world_size,
num_features=num_features,
@@ -235,93 +246,5 @@ def create_output_dist(
return InferRwSequenceEmbeddingDist(
device if device is not None else self._device,
self._world_size,
)


class InferCPURwSequenceEmbeddingDist(
BaseEmbeddingDist[
InferSequenceShardingContext, List[torch.Tensor], List[torch.Tensor]
]
):
def __init__(
self,
device: torch.device,
world_size: int,
) -> None:
super().__init__()

def forward(
self,
local_embs: List[torch.Tensor],
sharding_ctx: Optional[InferSequenceShardingContext] = None,
) -> List[torch.Tensor]:
# for cpu sharder, output dist should be a no-op
return local_embs


class InferCPURwSequenceEmbeddingSharding(
BaseRwEmbeddingSharding[
InferSequenceShardingContext,
InputDistOutputs,
List[torch.Tensor],
List[torch.Tensor],
]
):
"""
Shards sequence (unpooled) row-wise, i.e.. a given embedding table is evenly
distributed by rows and table slices are placed on all ranks for inference.
"""

def create_input_dist(
self,
device: Optional[torch.device] = None,
) -> BaseSparseFeaturesDist[InputDistOutputs]:
num_features = self._get_num_features()
feature_hash_sizes = self._get_feature_hash_sizes()

emb_sharding = []
for embedding_table_group in self._grouped_embedding_configs_per_rank[0]:
for table in embedding_table_group.embedding_tables:
shard_split_offsets = [
shard.shard_offsets[0]
# pyre-fixme[16]: `Optional` has no attribute `shards_metadata`.
for shard in table.global_metadata.shards_metadata
]
# pyre-fixme[16]: Optional has no attribute size.
shard_split_offsets.append(table.global_metadata.size[0])
emb_sharding.extend([shard_split_offsets] * len(table.embedding_names))

return InferRwSparseFeaturesDist(
world_size=self._world_size,
num_features=num_features,
feature_hash_sizes=feature_hash_sizes,
device=device if device is not None else self._device,
is_sequence=True,
has_feature_processor=self._has_feature_processor,
need_pos=False,
embedding_shard_metadata=emb_sharding,
)

def create_lookup(
self,
device: Optional[torch.device] = None,
fused_params: Optional[Dict[str, Any]] = None,
feature_processor: Optional[BaseGroupedFeatureProcessor] = None,
) -> BaseEmbeddingLookup[InputDistOutputs, List[torch.Tensor]]:
return InferCPUGroupedEmbeddingsLookup(
grouped_configs_per_rank=self._grouped_embedding_configs_per_rank,
world_size=self._world_size,
fused_params=fused_params,
device=device if device is not None else self._device,
)

def create_output_dist(
self,
device: Optional[torch.device] = None,
) -> BaseEmbeddingDist[
InferSequenceShardingContext, List[torch.Tensor], List[torch.Tensor]
]:
return InferCPURwSequenceEmbeddingDist(
device if device is not None else self._device,
self._world_size,
self._device_type_from_sharding_infos,
)
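
Note on this file: InferRwSequenceEmbeddingDist now receives the device-type hint directly, which is what makes the dedicated InferCPURwSequenceEmbedding* classes above removable. A self-contained sketch of the conditional no-op pattern, with a plain callable standing in for SeqEmbeddingsAllToOne; the class name is illustrative:

from typing import Callable, List, Optional

import torch


class NoOpAwareSeqEmbeddingDist(torch.nn.Module):
    """Illustrative stand-in for InferRwSequenceEmbeddingDist's new behavior."""

    def __init__(
        self,
        all_to_one: Callable[[List[torch.Tensor]], List[torch.Tensor]],
        device_type_from_sharding_infos: Optional[str] = None,
    ) -> None:
        super().__init__()
        # all_to_one plays the role of SeqEmbeddingsAllToOne(device, world_size).
        self._all_to_one = all_to_one
        self._device_type = device_type_from_sharding_infos

    def forward(self, local_embs: List[torch.Tensor]) -> List[torch.Tensor]:
        # For cpu sharding the output dist is a no-op: the lookups already
        # produced host tensors, so no device exchange is needed.
        if self._device_type == "cpu":
            return local_embs
        return self._all_to_one(local_embs)


# Usage sketch: the stand-in copy is skipped entirely on the cpu path.
embs = [torch.randn(2, 4), torch.randn(3, 4)]
cpu_dist = NoOpAwareSeqEmbeddingDist(lambda xs: [x.clone() for x in xs], "cpu")
assert cpu_dist(embs) is embs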
4 changes: 4 additions & 0 deletions torchrec/distributed/sharding/rw_sharding.py
@@ -118,6 +118,7 @@ def __init__(
device: Optional[torch.device] = None,
need_pos: bool = False,
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
device_type_from_sharding_infos: Optional[str] = None,
) -> None:
super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
self._env = env
@@ -132,6 +133,9 @@ def __init__(
if device is None:
device = torch.device("cpu")
self._device: torch.device = device
self._device_type_from_sharding_infos: Optional[str] = (
device_type_from_sharding_infos
)
sharded_tables_per_rank = self._shard(sharding_infos)
self._need_pos = need_pos
self._grouped_embedding_configs_per_rank: List[List[GroupedEmbeddingConfig]] = (
5 changes: 5 additions & 0 deletions torchrec/distributed/test_utils/infer_utils.py
@@ -81,13 +81,15 @@
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection
from torchrec.quant.embedding_modules import (
EmbeddingCollection as QuantEmbeddingCollection,
FeatureProcessedEmbeddingBagCollection as QuantFeatureProcessedEmbeddingBagCollection,
MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
MODULE_ATTR_REGISTER_TBES_BOOL,
quant_prep_enable_quant_state_dict_split_scale_bias_for_types,
quant_prep_enable_register_tbes,
QuantManagedCollisionEmbeddingCollection,
)


@@ -331,6 +333,7 @@ def quantize(
module_types: List[Type[torch.nn.Module]] = [
torchrec.modules.embedding_modules.EmbeddingBagCollection,
torchrec.modules.embedding_modules.EmbeddingCollection,
torchrec.modules.mc_embedding_modules.ManagedCollisionEmbeddingCollection,
]
if register_tbes:
quant_prep_enable_register_tbes(module, module_types)
@@ -356,10 +359,12 @@ def quantize(
qconfig_spec={
EmbeddingBagCollection: qconfig,
EmbeddingCollection: qconfig,
ManagedCollisionEmbeddingCollection: qconfig,
},
mapping={
EmbeddingBagCollection: QuantEmbeddingBagCollection,
EmbeddingCollection: QuantEmbeddingCollection,
ManagedCollisionEmbeddingCollection: QuantManagedCollisionEmbeddingCollection,
},
inplace=inplace,
)
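
Note on the quantize helper above: ManagedCollisionEmbeddingCollection is now quantized like the other collections by adding it to both the qconfig_spec and the class mapping passed to torch.quantization.quantize_dynamic, which swaps each matched module via the mapped class's from_float. A hedged, standalone sketch of that mechanism with hypothetical FloatThing/QuantThing modules (not TorchRec classes):

import torch


class FloatThing(torch.nn.Module):
    """Hypothetical float module standing in for a TorchRec collection."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 8))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight


class QuantThing(torch.nn.Module):
    """Hypothetical quantized counterpart; here it just stores fp16 weights."""

    def __init__(self, weight: torch.Tensor) -> None:
        super().__init__()
        self.register_buffer("weight", weight.to(torch.float16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight.float()

    @classmethod
    def from_float(
        cls, mod: FloatThing, use_precomputed_fake_quant: bool = False
    ) -> "QuantThing":
        # quantize_dynamic swaps each module matched by qconfig_spec/mapping
        # by calling the mapped class's from_float.
        return cls(mod.weight.detach())


model = torch.nn.Sequential(FloatThing())
quantized = torch.quantization.quantize_dynamic(
    model,
    qconfig_spec={FloatThing: torch.quantization.default_dynamic_qconfig},
    mapping={FloatThing: QuantThing},
    inplace=False,
)
print(quantized(torch.randn(2, 4)).shape)  # torch.Size([2, 8])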
7 changes: 3 additions & 4 deletions torchrec/distributed/tests/test_infer_shardings.py
@@ -2195,8 +2195,7 @@ def test_sharded_quant_mc_ec_rw(
)
quant_model = mi.quant_model
assert quant_model.training is False
print(f"quant_model:\n{quant_model}")
non_sharded_output, _ = mi.quant_model(*inputs[0])
non_sharded_output = mi.quant_model(*inputs[0])

topology: Topology = Topology(world_size=world_size, compute_device=device_type)
mi.planner = EmbeddingShardingPlanner(
@@ -2231,7 +2230,7 @@ def test_sharded_quant_mc_ec_rw(
print(f"sharded_model.MODULE[{n}]:{type(m)}")

sharded_model.load_state_dict(quant_model.state_dict())
sharded_output, _ = sharded_model(*inputs[0])
sharded_output = sharded_model(*inputs[0])

assert_close(non_sharded_output, sharded_output)
gm: torch.fx.GraphModule = symbolic_trace(
@@ -2245,7 +2244,7 @@
print(f"fx.graph:\n{gm.graph}")
gm_script = torch.jit.script(gm)
print(f"gm_script:\n{gm_script}")
gm_script_output, _ = gm_script(*inputs[0])
gm_script_output = gm_script(*inputs[0])
assert_close(sharded_output, gm_script_output)

@unittest.skipIf(