add virtual size support for shardedTensor #2915

Open
wants to merge 1 commit into base: main

133 changes: 121 additions & 12 deletions torchrec/distributed/embedding_kernel.py
@@ -8,6 +8,7 @@
# pyre-strict

import abc
import copy
import logging
from collections import defaultdict, OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -26,7 +27,12 @@
ShardedEmbeddingTable,
)
from torchrec.distributed.shards_wrapper import LocalShardsWrapper
from torchrec.distributed.types import Shard, ShardedTensor, ShardedTensorMetadata
from torchrec.distributed.types import (
Shard,
ShardedTensor,
ShardedTensorMetadata,
ShardMetadata,
)
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

logger: logging.Logger = logging.getLogger(__name__)
@@ -56,6 +62,100 @@ def config(self) -> GroupedEmbeddingConfig:
pass


def manually_craft_local_metadata(
local_metadata: ShardMetadata,
param: Union[torch.Tensor, PartiallyMaterializedTensor],
my_rank: int,
) -> None:
local_metadata.shard_sizes = list(param.size())
local_metadata.shard_offsets = [
my_rank if dim == 0 else 0 for dim in range(len(param.size()))
]
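
# A minimal worked example of the helper above, assuming my_rank == 1 and a local
# `param` of shape [10, 4] (both values are hypothetical, chosen only for illustration):
#     local_metadata.shard_sizes   == [10, 4]
#     local_metadata.shard_offsets == [1, 0]
# The row offset equals my_rank because every lower rank is treated as holding
# exactly one row (see manually_crafted_global_metadata below).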


def manually_crafted_global_metadata(
metadata: ShardedTensorMetadata,
my_rank: int,
param: Union[torch.Tensor, PartiallyMaterializedTensor],
) -> None:
# update tensor properties from the local tensor's properties; these should be identical across all ranks
metadata.tensor_properties.dtype = param.dtype
metadata.tensor_properties.requires_grad = param.requires_grad

# manually craft the metadata, faking it so that every other rank appears to hold only 1 row,
# in order to pass the ShardedTensor shard-overlap checks
# NOTE: this currently only works for row-wise sharding
fake_total_rows = param.size()[0] + len(metadata.shards_metadata) - 1
metadata.size = torch.Size(
[
fake_total_rows if dim == 0 else param.size(dim)
for dim in range(len(param.size()))
]
)

for rank, shard_metadata in enumerate(metadata.shards_metadata):
if rank < my_rank:
shard_metadata.shard_sizes = [
1 if dim == 0 else param.size(dim) for dim in range(len(param.size()))
]
shard_metadata.shard_offsets = [
rank if dim == 0 else 0 for dim in range(len(param.size()))
]
elif rank == my_rank:
manually_craft_local_metadata(shard_metadata, param, my_rank)
else:
shard_metadata.shard_sizes = [
1 if dim == 0 else param.size(dim) for dim in range(len(param.size()))
]
shard_metadata.shard_offsets = [
rank + param.size(0) - 1 if dim == 0 else 0
for dim in range(len(param.size()))
]
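
# A minimal worked example of the metadata faking above, assuming 4 row-wise shards
# (one per rank), my_rank == 1, and a local `param` of shape [10, 4] (hypothetical values):
#     metadata.size == torch.Size([13, 4])                  # 10 + 4 - 1 fake total rows
#     shards_metadata[0]: offsets [0, 0],  sizes [1, 4]     # rank < my_rank: one fake row
#     shards_metadata[1]: offsets [1, 0],  sizes [10, 4]    # my_rank: the real local shard
#     shards_metadata[2]: offsets [11, 0], sizes [1, 4]     # rank > my_rank: one fake row
#     shards_metadata[3]: offsets [12, 0], sizes [1, 4]
# The fake single-row shards keep the row-wise shards contiguous and non-overlapping,
# which is what the ShardedTensor metadata validation requires.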


def create_virtual_sharded_tensors(
embedding_tables: List[ShardedEmbeddingTable],
params: Union[List[torch.Tensor], List[PartiallyMaterializedTensor]],
pg: Optional[dist.ProcessGroup] = None,
prefix: str = "",
) -> List[ShardedTensor]:
key_to_local_shards: Dict[str, List[Shard]] = defaultdict(list)
key_to_global_metadata: Dict[str, ShardedTensorMetadata] = {}

def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
return prefix + f"{embedding_table.name}"

my_rank = dist.get_rank()
for embedding_table, param in zip(embedding_tables, params):
key = get_key_from_embedding_table(embedding_table)
assert embedding_table.kv_zch

assert embedding_table.global_metadata is not None and pg is not None
global_metadata = copy.deepcopy(embedding_table.global_metadata)
manually_crafted_global_metadata(global_metadata, my_rank, param)
key_to_global_metadata[key] = global_metadata

assert embedding_table.local_metadata is not None
local_metadata = copy.deepcopy(embedding_table.local_metadata)
manually_craft_local_metadata(local_metadata, param, my_rank)

key_to_local_shards[key].append(Shard(param, local_metadata))

result: List[ShardedTensor] = []
if pg is not None:
for key in key_to_local_shards:
global_metadata = key_to_global_metadata[key]
result.append(
ShardedTensor._init_from_local_shards_and_global_metadata(
local_shards=key_to_local_shards[key],
sharded_tensor_metadata=global_metadata,
process_group=pg,
skip_tensor_size_check=True,
)
)
return result
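
# A minimal usage sketch of create_virtual_sharded_tensors; the names `table`,
# `weight`, `process_group`, and the prefix below are hypothetical, not from this change:
#
#     virtual_tensors = create_virtual_sharded_tensors(
#         embedding_tables=[table],      # a kv_zch ShardedEmbeddingTable
#         params=[weight],               # its locally materialized rows
#         pg=process_group,              # an initialized ProcessGroup
#         prefix="embedding_bags.",
#     )
#
# Each returned ShardedTensor is rebuilt from the local shard plus the faked global
# metadata via _init_from_local_shards_and_global_metadata with skip_tensor_size_check=True.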


def get_state_dict(
embedding_tables: List[ShardedEmbeddingTable],
params: Union[
@@ -68,6 +168,7 @@ def get_state_dict(
pg: Optional[dist.ProcessGroup] = None,
destination: Optional[Dict[str, Any]] = None,
prefix: str = "",
use_virtual_size: bool = False,
) -> Dict[str, Any]:
if destination is None:
destination = OrderedDict()
@@ -84,11 +185,15 @@
# pyre-ignore[33]
key_to_local_tensor_shards: Dict[str, List[Any]] = defaultdict(list)

# validate the function input for kv zch cases: use_virtual_size must match each table's kv_zch flag
for emb_table in embedding_tables:
assert use_virtual_size == emb_table.kv_zch

def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
return prefix + f"{embedding_table.name}.weight"

for embedding_table, param in zip(embedding_tables, params):
key = get_key_from_embedding_table(embedding_table)
weights_key = get_key_from_embedding_table(embedding_table)
is_quant = embedding_table.compute_kernel in [
EmbeddingComputeKernel.QUANT,
EmbeddingComputeKernel.QUANT_UVM,
@@ -103,17 +208,18 @@ def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
qbias = param[2]
param = param[0]

assert embedding_table.local_rows == param.size( # pyre-ignore[16]
0
), f"{embedding_table.local_rows=}, {param.size(0)=}, {param.shape=}" # pyre-ignore[16]
if not embedding_table.kv_zch:
assert embedding_table.local_rows == param.size( # pyre-ignore[16]
0
), f"{embedding_table.local_rows=}, {param.size(0)=}, {param.shape=}" # pyre-ignore[16]

if qscale is not None:
assert embedding_table.local_cols == param.size(1) # pyre-ignore[16]

if embedding_table.dtensor_metadata is not None and pg is not None:
# DTensor path
key_to_dtensor_metadata[key] = embedding_table.dtensor_metadata
key_to_local_tensor_shards[key].append(
key_to_dtensor_metadata[weights_key] = embedding_table.dtensor_metadata
key_to_local_tensor_shards[weights_key].append(
[
param,
embedding_table.local_metadata.shard_offsets, # pyre-ignore[16]
@@ -127,21 +233,23 @@ def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
embedding_table.global_metadata.tensor_properties.requires_grad = (
param.requires_grad # pyre-ignore[16]
)
key_to_global_metadata[key] = embedding_table.global_metadata
key_to_global_metadata[weights_key] = embedding_table.global_metadata

key_to_local_shards[key].append(
# for kv zch cases we use the virtual space; the logic is the same as for non-kv zch cases
key_to_local_shards[weights_key].append(
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `Union[Module, Tensor]`.
# pyre-fixme[6]: For 2nd argument expected `ShardMetadata` but got
# `Optional[ShardMetadata]`.
Shard(param, embedding_table.local_metadata)
)

else:
destination[key] = param
destination[weights_key] = param
if qscale is not None:
destination[f"{key}_qscale"] = qscale
destination[f"{weights_key}_qscale"] = qscale
if qbias is not None:
destination[f"{key}_qbias"] = qbias
destination[f"{weights_key}_qbias"] = qbias

if pg is not None:
# Populate the remaining destinations that have a global metadata
@@ -152,6 +260,7 @@ def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
local_shards=key_to_local_shards[key],
sharded_tensor_metadata=global_metadata,
process_group=pg,
skip_tensor_size_check=use_virtual_size,
)
)
# DTensor path