
support zero collision tables in ssd operator #4033


Open · wants to merge 1 commit into main
18 changes: 18 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py
@@ -49,6 +49,24 @@ def from_str(cls, key: str):
raise ValueError(f"Cannot parse value into EmbeddingLocation: {key}")


class BackendType(enum.IntEnum):
SSD = 0
DRAM = 1
PS = 2

@classmethod
# pyre-ignore[3]
def from_str(cls, key: str):
lookup = {
"ssd": BackendType.SSD,
"dram": BackendType.DRAM,
}
if key in lookup:
return lookup[key]
else:
raise ValueError(f"Cannot parse value into BackendType: {key}")


class CacheAlgorithm(enum.Enum):
LRU = 0
LFU = 1
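
A minimal usage sketch of the new enum (the config plumbing around it is assumed, not part of this diff):

from fbgemm_gpu.split_table_batched_embeddings_ops_common import BackendType

# Parse a backend choice coming from a string config; only "ssd" and "dram" are
# parseable here, so the PS backend must be selected explicitly as BackendType.PS.
backend = BackendType.from_str("dram")
assert backend == BackendType.DRAM
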
200 changes: 162 additions & 38 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -29,6 +29,7 @@
)
from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
BackendType,
BoundsCheckMode,
CacheAlgorithm,
EmbeddingLocation,
@@ -89,7 +90,12 @@ class SSDTableBatchedEmbeddingBags(nn.Module):

def __init__(
self,
embedding_specs: List[Tuple[int, int]], # tuple of (rows, dims)
# tuple of (rows, dims) for an MCH embedding table, or
# tuple of (virtual global table size, dims, virtual local table size) for a KV embedding table
# the virtual global table size is needed because we take in the original input id
# the virtual local table size is needed when we create a ShardedTensor during init,
# since the emb table size is 0 during init and cannot pass the ShardedTensor construction check
embedding_specs: Union[List[Tuple[int, int]], List[Tuple[int, int, int]]],
feature_table_map: Optional[List[int]], # [T]
cache_sets: int,
ssd_storage_directory: str,
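
For illustration only (hypothetical sizes, not part of this diff), the two accepted embedding_specs layouts look like:

# plain / MCH tables: (rows, dims)
mch_specs = [(1_000_000, 128), (500_000, 64)]
# KV / zero-collision tables: (virtual global table size, dims, virtual local table size)
kv_specs = [(2**50, 128, 1_000_000)]
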
@@ -152,15 +158,22 @@ def __init__(
# number of rows will be decided by bulk_init_chunk_size / size_of_each_row
bulk_init_chunk_size: int = 0,
lazy_bulk_init_enabled: bool = False,
zero_collision_tbe: Optional[bool] = False,
backend_type: BackendType = BackendType.SSD,
# global bucket id start and global bucket id end offsets for each logical table,
# where start offset is inclusive and end offset is exclusive
bucket_offsets: Optional[List[Tuple[int, int]]] = None,
# one bucket size for each logical table;
# the value indicates the input-id space covered by each bucket id, e.g. 2^50 / total_num_buckets
bucket_sizes: Optional[List[int]] = None,
) -> None:
super(SSDTableBatchedEmbeddingBags, self).__init__()

self.pooling_mode = pooling_mode
self.bounds_check_mode_int: int = bounds_check_mode.value
self.embedding_specs = embedding_specs
(rows, dims) = zip(*embedding_specs)
T_ = len(self.embedding_specs)
assert T_ > 0
rows, dims, *_ = zip(*embedding_specs)
# pyre-fixme[8]: Attribute has type `device`; used as `int`.
self.current_device: torch.device = torch.cuda.current_device()
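
As a hedged sketch of how the new bucket arguments could be derived (numbers are hypothetical; the real sharding logic lives in the caller):

total_num_buckets = 1024              # buckets spanning the virtual id space of one table
virtual_table_size = 2**50            # virtual global table size from embedding_specs
bucket_size = virtual_table_size // total_num_buckets  # input-id space owned by one bucket id

# Suppose this rank owns global bucket ids [256, 512): start inclusive, end exclusive.
# These values would be passed to the constructor as bucket_offsets=[(256, 512)] and
# bucket_sizes=[bucket_size], together with zero_collision_tbe=True.
bucket_offsets = [(256, 512)]
bucket_sizes = [bucket_size]
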

@@ -426,6 +439,19 @@ def __init__(
)
# logging.info("DEBUG: weights_precision {}".format(weights_precision))

# zero collision TBE configurations
self.zero_collision_tbe = zero_collision_tbe
self.backend_type = backend_type
self.bucket_offsets: Optional[List[Tuple[int, int]]] = bucket_offsets
self.bucket_sizes: Optional[List[int]] = bucket_sizes
if self.zero_collision_tbe:
assert (
bucket_offsets is not None and bucket_sizes is not None
), "bucket_offsets and bucket_sizes must be provided for zero_collision_tbe"
assert len(bucket_offsets) == len(
bucket_sizes
), "bucket_offsets and bucket_sizes must have the same length"

# create tbe unique id using rank index | local tbe idx
if tbe_unique_id == -1:
SSDTableBatchedEmbeddingBags._local_instance_index += 1
@@ -442,7 +468,7 @@
tbe_unique_id = SSDTableBatchedEmbeddingBags._local_instance_index
self.tbe_unique_id = tbe_unique_id
logging.info(f"tbe_unique_id: {tbe_unique_id}")
if not ps_hosts:
if self.backend_type == BackendType.SSD:
logging.info(
f"Logging SSD offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB, enable_async_update:{enable_async_update}"
f"passed_in_path={ssd_directory}, num_shards={ssd_rocksdb_shards},num_threads={ssd_rocksdb_shards},"
@@ -485,11 +511,9 @@ def __init__(
self._lazy_initialize_ssd_tbe()
else:
self._insert_all_kv()
else:
# pyre-fixme[4]: Attribute must be annotated.
# pyre-ignore[16]
elif self.backend_type == BackendType.PS:
self._ssd_db = torch.classes.fbgemm.EmbeddingParameterServerWrapper(
[host[0] for host in ps_hosts],
[host[0] for host in ps_hosts], # pyre-ignore
[host[1] for host in ps_hosts],
tbe_unique_id,
(
@@ -502,6 +526,9 @@
l2_cache_size,
self.max_D,
)
else:
raise AssertionError(f"Invalid backend type {self.backend_type}")

# pyre-fixme[20]: Argument `self` expected.
(low_priority, high_priority) = torch.cuda.Stream.priority_range()
# GPU stream for SSD cache eviction
@@ -1718,23 +1745,53 @@ def forward(
)

@torch.jit.ignore
def debug_split_optimizer_states(self) -> List[Tuple[torch.Tensor]]:
def debug_split_optimizer_states(self) -> List[Tuple[torch.Tensor, int, int]]:
"""
Returns a list of states, split by table
Returns a list of (optimizer state, table_input_id_start, table_input_id_end) tuples, split by table.
Testing only.
"""
(rows, _) = zip(*self.embedding_specs)
(rows, _, *_) = zip(*self.embedding_specs)

rows_cumsum = [0] + list(itertools.accumulate(rows))

return [
(
self.momentum1_dev.detach()[rows_cumsum[t] : rows_cumsum[t + 1]].view(
row
),
)
for t, row in enumerate(rows)
]
if self.zero_collision_tbe:
opt_list = []
table_offset = 0
for t, row in enumerate(rows):
# pyre-ignore
bucket_id_start, bucket_id_end = self.bucket_offsets[t]
# pyre-ignore
bucket_size = self.bucket_sizes[t]
table_input_id_start = bucket_id_start * bucket_size + table_offset
table_input_id_end = bucket_id_end * bucket_size + table_offset

# TODO: this is a hack for preallocated optimizer, update this part once we have optimizer offloading
unlinearized_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot(
table_input_id_start,
table_input_id_end,
0, # no need for table offset, as optimizer is preallocated using table offset
None,
)
sorted_offsets, _ = torch.sort(unlinearized_id_tensor.view(-1))
opt_list.append(
(
self.momentum1_dev.detach()[sorted_offsets],
table_input_id_start - table_offset,
table_input_id_end - table_offset,
)
)
table_offset += row
return opt_list
else:
return [
(
self.momentum1_dev.detach()[
rows_cumsum[t] : rows_cumsum[t + 1]
].view(row),
-1,
-1,
)
for t, row in enumerate(rows)
]
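
To make the id linearization above concrete, a small worked example with hypothetical numbers:

table_offset = 0                            # cumulative row offset of this table inside the TBE
bucket_id_start, bucket_id_end = 256, 512   # from bucket_offsets[t]
bucket_size = 2**50 // 1024                 # from bucket_sizes[t]

# The backend is queried for all keys whose linearized id falls in this range.
table_input_id_start = bucket_id_start * bucket_size + table_offset
table_input_id_end = bucket_id_end * bucket_size + table_offset
assert table_input_id_end - table_input_id_start == (bucket_id_end - bucket_id_start) * bucket_size
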

@torch.jit.export
def debug_split_embedding_weights(self) -> List[torch.Tensor]:
@@ -1786,7 +1843,11 @@ def split_embedding_weights(
self,
no_snapshot: bool = True,
should_flush: bool = False,
) -> List[PartiallyMaterializedTensor]:
) -> Tuple[
List[PartiallyMaterializedTensor],
Optional[List[torch.Tensor]],
Optional[List[torch.Tensor]],
]:
"""
This method is intended to be used by the checkpointing engine
only.
@@ -1800,34 +1861,97 @@
operation, only set to True when necessary.

Returns:
a list of partially materialized tensors, each representing a table
a tuple of 3 lists, where each element corresponds to a logical table:
1st: partially materialized tensors, each representing a table
2nd: input ids sorted in ascending bucket id order
3rd: active id count per bucket id; the tensor size is [bucket_id_end - bucket_id_start],
where the i-th element corresponds to global bucket id i + bucket_id_start
"""
# Force device synchronize for now
torch.cuda.synchronize()
# Create a snapshot
if no_snapshot:
snapshot_handle = None
else:
if should_flush:
# Flush L1 and L2 caches
self.flush()
snapshot_handle = self.ssd_db.create_snapshot()
snapshot_handle = None
if self.backend_type == BackendType.SSD:
# Create a rocksdb snapshot
if not no_snapshot:
if should_flush:
# Flush L1 and L2 caches
self.flush()
snapshot_handle = self.ssd_db.create_snapshot()
elif self.backend_type == BackendType.DRAM:
self.flush()

dtype = self.weights_precision.as_dtype()
splits = []
pmt_splits = []
bucket_sorted_id_splits = [] if self.zero_collision_tbe else None
active_id_cnt_per_bucket_split = [] if self.zero_collision_tbe else None

row_offset = 0
for emb_height, emb_dim in self.embedding_specs:
table_offset = 0
for i, (emb_height, emb_dim, *virtual_local_rows) in enumerate(
self.embedding_specs
):
bucket_ascending_id_tensor = None
bucket_t = None
if self.zero_collision_tbe:
# pyre-ignore
bucket_id_start, bucket_id_end = self.bucket_offsets[i]
# pyre-ignore
bucket_size = self.bucket_sizes[i]

# linearize with table offset
table_input_id_start = bucket_id_start * bucket_size + table_offset
table_input_id_end = bucket_id_end * bucket_size + table_offset
# 1. get all keys from backend for one table
unordered_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot(
table_input_id_start,
table_input_id_end,
table_offset,
snapshot_handle,
)
# 2. sort keys in ascending bucket id order
bucket_ascending_id_tensor, bucket_t = (
torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor(
unordered_id_tensor,
0, # id-to-bucket hashing mode: 0 for chunk-based hashing, 1 for interleave-based hashing
bucket_id_start,
bucket_id_end,
bucket_size,
)
)
# pyre-ignore
bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
# pyre-ignore
active_id_cnt_per_bucket_split.append(bucket_t)

if virtual_local_rows:
table_rows = virtual_local_rows[0]
assert (
self.zero_collision_tbe
), "virtual local rows are only supported for KV_ZCH"
else:
table_rows = emb_height
tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
db=self.ssd_db,
shape=[emb_height, emb_dim],
shape=[table_rows, emb_dim],
dtype=dtype,
row_offset=row_offset,
row_offset=table_offset,
snapshot_handle=snapshot_handle,
materialized_shape=(
[bucket_ascending_id_tensor.size(0), emb_dim]
if self.zero_collision_tbe
# no_snapshot means this call is for the TorchRec ShardedTensor init;
# we don't return the materialized size, in order to pass the virtual ShardedTensor check
and not no_snapshot
else None
),
)
row_offset += emb_height
splits.append(PartiallyMaterializedTensor(tensor_wrapper))
return splits
table_offset += emb_height
pmt_splits.append(
PartiallyMaterializedTensor(
tensor_wrapper,
self.zero_collision_tbe if self.zero_collision_tbe else False,
)
)
return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split)

@torch.jit.export
def set_learning_rate(self, lr: float) -> None:
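
A hedged sketch of how a checkpointing caller might consume the new three-list return value (assumes an already-constructed module named emb; not part of this diff):

pmts, bucket_sorted_ids, active_id_counts = emb.split_embedding_weights(
    no_snapshot=False, should_flush=True
)
for t, pmt in enumerate(pmts):
    if bucket_sorted_ids is not None and active_id_counts is not None:
        ids = bucket_sorted_ids[t]      # input ids for table t, ascending by bucket id
        counts = active_id_counts[t]    # active id count per bucket for this shard
        # Bucket boundaries within ids can be recovered with a prefix sum over counts.
        bucket_boundaries = counts.cumsum(dim=0)
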
@@ -33,7 +33,7 @@ class PartiallyMaterializedTensor:
or use `full_tensor()` to get the full tensor (this could OOM).
"""

def __init__(self, wrapped) -> None:
def __init__(self, wrapped, is_virtual_size: bool = False) -> None:
"""
Ensure caller loads the module before creating this object.

@@ -48,6 +48,7 @@ def __init__(self, wrapped) -> None:
wrapped: torch.classes.fbgemm.KVTensorWrapper
"""
self._wrapped = wrapped
self._is_virtual_size = is_virtual_size
self._requires_grad = False

@property
@@ -57,6 +58,15 @@ def wrapped(self):
"""
return self._wrapped

@property
def is_virtual(self):
"""
Indicates whether this PMT is a virtual tensor, in which case the actual tensor shape differs from the virtual size.
This indicator is needed for checkpoint and publish, which must call all-gather to recalculate the correct
metadata of the ShardedTensor.
"""
return self._is_virtual_size

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
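
A minimal sketch of how the new flag might be consulted by a checkpointing caller (assumes an already-constructed module named emb; the ShardedTensor handling itself is out of scope here):

pmts, _, _ = emb.split_embedding_weights(no_snapshot=False)
virtual_tables = [t for t, pmt in enumerate(pmts) if pmt.is_virtual]
# For these tables, checkpoint/publish would all-gather the real shard sizes
# before rebuilding the ShardedTensor metadata; the remaining tables can be used as-is.
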
@@ -46,7 +46,7 @@ inline size_t hash_shard(int64_t id, size_t num_shards) {
///
/// @param unordered_indices unordered ids, the id here might be
/// original(unlinearized) id
/// @param hash_mode 0 for hash by mod, 1 for hash by interleave
/// @param hash_mode 0 for chunk-based hashing, 1 for interleave-based hashing
/// @param bucket_start global bucket id, the start of the bucket range
/// @param bucket_end global bucket id, the end of the bucket range
/// @param bucket_size an optional, virtual size(input space, e.g. 2^50) of a
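
For reference, a hedged Python-side sketch of calling the op documented above (toy values; the exact input tensor shape and the chunk-based hashing semantics are assumptions, not confirmed by this diff):

import torch
import fbgemm_gpu  # noqa: F401  # loads the extension so torch.ops.fbgemm.* is registered

bucket_size = 100
unordered_ids = torch.tensor([5, 205, 12], dtype=torch.int64)  # unlinearized ids

sorted_ids, per_bucket_cnt = torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor(
    unordered_ids,
    0,            # hash_mode: 0 = chunk-based hashing, 1 = interleave-based hashing
    0,            # bucket_start (inclusive, global bucket id)
    4,            # bucket_end (exclusive, global bucket id)
    bucket_size,  # virtual input-id space per bucket
)
# sorted_ids holds the ids reordered in ascending bucket id order; per_bucket_cnt has
# bucket_end - bucket_start entries with the active id count per bucket.
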