diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py index 741a87f965..66563c8f46 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py @@ -49,6 +49,24 @@ def from_str(cls, key: str): raise ValueError(f"Cannot parse value into EmbeddingLocation: {key}") +class BackendType(enum.IntEnum): + SSD = 0 + DRAM = 1 + PS = 2 + + @classmethod + # pyre-ignore[3] + def from_str(cls, key: str): + lookup = { + "ssd": BackendType.SSD, + "dram": BackendType.DRAM, + } + if key in lookup: + return lookup[key] + else: + raise ValueError(f"Cannot parse value into BackendType: {key}") + + class CacheAlgorithm(enum.Enum): LRU = 0 LFU = 1 diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index 58ba227f27..cd58c05251 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -29,6 +29,7 @@ ) from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( + BackendType, BoundsCheckMode, CacheAlgorithm, EmbeddingLocation, @@ -89,7 +90,12 @@ class SSDTableBatchedEmbeddingBags(nn.Module): def __init__( self, - embedding_specs: List[Tuple[int, int]], # tuple of (rows, dims) + # tuple of (rows, dims) for mch embedding or + # tuple of (virtual global table size, dims, virtual local table size) for kv embedding + # virtual global table size is needed when we take in original input id + # virtual local table size is needed when we create shardedTensor during init + # emb table size is 0 during init, which can not pass the ShardedTensor construction checking + embedding_specs: Union[List[Tuple[int, int]], List[Tuple[int, int, int]]], feature_table_map: Optional[List[int]], # [T] cache_sets: int, ssd_storage_directory: str, @@ -152,15 +158,22 @@ def __init__( # number of rows will be decided by bulk_init_chunk_size / size_of_each_row bulk_init_chunk_size: int = 0, lazy_bulk_init_enabled: bool = False, + zero_collision_tbe: Optional[bool] = False, + backend_type: BackendType = BackendType.SSD, + # global bucket id start and global bucket id end offsets for each logical table, + # where start offset is inclusive and end offset is exclusive + bucket_offsets: Optional[List[Tuple[int, int]]] = None, + # one bucket size for each logical table + # the value indicates corresponding input space for each bucket id, e.g. 2^50 / total_num_buckets + bucket_sizes: Optional[List[int]] = None, ) -> None: super(SSDTableBatchedEmbeddingBags, self).__init__() - self.pooling_mode = pooling_mode self.bounds_check_mode_int: int = bounds_check_mode.value self.embedding_specs = embedding_specs - (rows, dims) = zip(*embedding_specs) T_ = len(self.embedding_specs) assert T_ > 0 + rows, dims, *_ = zip(*embedding_specs) # pyre-fixme[8]: Attribute has type `device`; used as `int`. 
self.current_device: torch.device = torch.cuda.current_device() @@ -426,6 +439,19 @@ def __init__( ) # logging.info("DEBUG: weights_precision {}".format(weights_precision)) + # zero collision TBE configurations + self.zero_collision_tbe = zero_collision_tbe + self.backend_type = backend_type + self.bucket_offsets: Optional[List[Tuple[int, int]]] = bucket_offsets + self.bucket_sizes: Optional[List[int]] = bucket_sizes + if self.zero_collision_tbe: + assert ( + bucket_offsets is not None and bucket_sizes is not None + ), "bucket_offsets and bucket_sizes must be provided for zero_collision_tbe" + assert len(bucket_offsets) == len( + bucket_sizes + ), "bucket_offsets and bucket_sizes must have the same length" + # create tbe unique id using rank index | local tbe idx if tbe_unique_id == -1: SSDTableBatchedEmbeddingBags._local_instance_index += 1 @@ -442,7 +468,7 @@ def __init__( tbe_unique_id = SSDTableBatchedEmbeddingBags._local_instance_index self.tbe_unique_id = tbe_unique_id logging.info(f"tbe_unique_id: {tbe_unique_id}") - if not ps_hosts: + if self.backend_type == BackendType.SSD: logging.info( f"Logging SSD offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB, enable_async_update:{enable_async_update}" f"passed_in_path={ssd_directory}, num_shards={ssd_rocksdb_shards},num_threads={ssd_rocksdb_shards}," @@ -485,11 +511,9 @@ def __init__( self._lazy_initialize_ssd_tbe() else: self._insert_all_kv() - else: - # pyre-fixme[4]: Attribute must be annotated. - # pyre-ignore[16] + elif self.backend_type == BackendType.PS: self._ssd_db = torch.classes.fbgemm.EmbeddingParameterServerWrapper( - [host[0] for host in ps_hosts], + [host[0] for host in ps_hosts], # pyre-ignore [host[1] for host in ps_hosts], tbe_unique_id, ( @@ -502,6 +526,9 @@ def __init__( l2_cache_size, self.max_D, ) + else: + raise AssertionError(f"Invalid backend type {self.backend_type}") + # pyre-fixme[20]: Argument `self` expected. 
(low_priority, high_priority) = torch.cuda.Stream.priority_range() # GPU stream for SSD cache eviction @@ -1718,23 +1745,53 @@ def forward( ) @torch.jit.ignore - def debug_split_optimizer_states(self) -> List[Tuple[torch.Tensor]]: + def debug_split_optimizer_states(self) -> List[Tuple[torch.Tensor, int, int]]: """ - Returns a list of states, split by table + Returns a list of (optimizer_state, table_input_id_start, table_input_id_end) tuples, split by table Testing only """ - (rows, _) = zip(*self.embedding_specs) + (rows, _, *_) = zip(*self.embedding_specs) rows_cumsum = [0] + list(itertools.accumulate(rows)) - - return [ - ( - self.momentum1_dev.detach()[rows_cumsum[t] : rows_cumsum[t + 1]].view( - row - ), - ) - for t, row in enumerate(rows) - ] + if self.zero_collision_tbe: + opt_list = [] + table_offset = 0 + for t, row in enumerate(rows): + # pyre-ignore + bucket_id_start, bucket_id_end = self.bucket_offsets[t] + # pyre-ignore + bucket_size = self.bucket_sizes[t] + table_input_id_start = bucket_id_start * bucket_size + table_offset + table_input_id_end = bucket_id_end * bucket_size + table_offset + + # TODO: this is a hack for the preallocated optimizer, update this part once we have optimizer offloading + unlinearized_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot( + table_input_id_start, + table_input_id_end, + 0, # no need for table offset, as the optimizer is preallocated using the table offset + None, + ) + sorted_offsets, _ = torch.sort(unlinearized_id_tensor.view(-1)) + opt_list.append( + ( + self.momentum1_dev.detach()[sorted_offsets], + table_input_id_start - table_offset, + table_input_id_end - table_offset, + ) + ) + table_offset += row + return opt_list + else: + return [ + ( + self.momentum1_dev.detach()[ + rows_cumsum[t] : rows_cumsum[t + 1] + ].view(row), + -1, + -1, + ) + for t, row in enumerate(rows) + ] @torch.jit.export def debug_split_embedding_weights(self) -> List[torch.Tensor]: @@ -1786,7 +1843,11 @@ def split_embedding_weights( self, no_snapshot: bool = True, should_flush: bool = False, - ) -> List[PartiallyMaterializedTensor]: + ) -> Tuple[ + List[PartiallyMaterializedTensor], + Optional[List[torch.Tensor]], + Optional[List[torch.Tensor]], + ]: """ This method is intended to be used by the checkpointing engine only. @@ -1800,34 +1861,97 @@ def split_embedding_weights( operation, only set to True when necessary.
Returns: - a list of partially materialized tensors, each representing a table + a tuple of 3 lists, where each element corresponds to a logical table + 1st arg: partially materialized tensors, each representing a table + 2nd arg: input ids sorted in bucket id ascending order + 3rd arg: active id count per bucket id, tensor size is [bucket_id_end - bucket_id_start] + where the i-th element corresponds to global bucket id (i + bucket_id_start) """ # Force device synchronize for now torch.cuda.synchronize() - # Create a snapshot - if no_snapshot: - snapshot_handle = None - else: - if should_flush: - # Flush L1 and L2 caches - self.flush() - snapshot_handle = self.ssd_db.create_snapshot() + snapshot_handle = None + if self.backend_type == BackendType.SSD: + # Create a rocksdb snapshot + if not no_snapshot: + if should_flush: + # Flush L1 and L2 caches + self.flush() + snapshot_handle = self.ssd_db.create_snapshot() + elif self.backend_type == BackendType.DRAM: + self.flush() dtype = self.weights_precision.as_dtype() - splits = [] + pmt_splits = [] + bucket_sorted_id_splits = [] if self.zero_collision_tbe else None + active_id_cnt_per_bucket_split = [] if self.zero_collision_tbe else None - row_offset = 0 - for emb_height, emb_dim in self.embedding_specs: + table_offset = 0 + for i, (emb_height, emb_dim, *virtual_local_rows) in enumerate( + self.embedding_specs + ): + bucket_ascending_id_tensor = None + bucket_t = None + if self.zero_collision_tbe: + # pyre-ignore + bucket_id_start, bucket_id_end = self.bucket_offsets[i] + # pyre-ignore + bucket_size = self.bucket_sizes[i] + + # linearize with table offset + table_input_id_start = bucket_id_start * bucket_size + table_offset + table_input_id_end = bucket_id_end * bucket_size + table_offset + # 1. get all keys from backend for one table + unordered_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot( + table_input_id_start, + table_input_id_end, + table_offset, + snapshot_handle, + ) + # 2.
sort keys in bucket ascending order + bucket_ascending_id_tensor, bucket_t = ( + torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor( + unordered_id_tensor, + 0, # id-to-bucket hashing mode, 0 for chunk-based hashing, 1 for interleave-based hashing + bucket_id_start, + bucket_id_end, + bucket_size, + ) + ) + # pyre-ignore + bucket_sorted_id_splits.append(bucket_ascending_id_tensor) + # pyre-ignore + active_id_cnt_per_bucket_split.append(bucket_t) + + if virtual_local_rows: + table_rows = virtual_local_rows[0] + assert ( + self.zero_collision_tbe + ), "virtual local rows are only supported for KV_ZCH" + else: + table_rows = emb_height tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper( db=self.ssd_db, - shape=[emb_height, emb_dim], + shape=[table_rows, emb_dim], dtype=dtype, - row_offset=row_offset, + row_offset=table_offset, snapshot_handle=snapshot_handle, + materialized_shape=( + [bucket_ascending_id_tensor.size(0), emb_dim] + if self.zero_collision_tbe + # no_snapshot means this is for the torchrec ShardedTensor init path, + # where we don't return the materialized size, so that the virtual ShardedTensor check passes + and not no_snapshot + else None + ), ) - row_offset += emb_height - splits.append(PartiallyMaterializedTensor(tensor_wrapper)) - return splits + table_offset += emb_height + pmt_splits.append( + PartiallyMaterializedTensor( + tensor_wrapper, + self.zero_collision_tbe if self.zero_collision_tbe else False, + ) + ) + return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split) @torch.jit.export def set_learning_rate(self, lr: float) -> None: diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py index 3d2f24a939..a308fe9357 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py @@ -33,7 +33,7 @@ class PartiallyMaterializedTensor: or use `full_tensor()` to get the full tensor (this could OOM). """ - def __init__(self, wrapped) -> None: + def __init__(self, wrapped, is_virtual_size: bool = False) -> None: """ Ensure caller loads the module before creating this object. @@ -48,6 +48,7 @@ def __init__(self, wrapped) -> None: wrapped: torch.classes.fbgemm.KVTensorWrapper """ self._wrapped = wrapped + self._is_virtual_size = is_virtual_size self._requires_grad = False @property @@ -57,6 +58,15 @@ def wrapped(self): """ return self._wrapped + @property + def is_virtual(self): + """ + Indicates whether the PMT is a virtual tensor, whose actual tensor shape may differ from its virtual size. + This indicator is needed for checkpointing or publishing.
They need to call all-gather to recalculate the correct + metadata of the ShardedTensor + """ + return self._is_virtual_size + @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/kv_db_cpp_utils.h b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/kv_db_cpp_utils.h index 5184886868..980ca5ae1f 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/kv_db_cpp_utils.h +++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/kv_db_cpp_utils.h @@ -46,7 +46,7 @@ inline size_t hash_shard(int64_t id, size_t num_shards) { /// /// @param unordered_indices unordered ids, the id here might be /// original(unlinearized) id -/// @param hash_mode 0 for hash by mod, 1 for hash by interleave +/// @param hash_mode 0 for chunk-based hashing, 1 for interleaved-based hashing /// @param bucket_start global bucket id, the start of the bucket range /// @param bucket_end global bucket id, the end of the bucket range /// @param bucket_size an optional, virtual size(input space, e.g. 2^50) of a diff --git a/fbgemm_gpu/src/split_embeddings_cache/kv_db_cpp_utils.cpp b/fbgemm_gpu/src/split_embeddings_cache/kv_db_cpp_utils.cpp index 6bbf504b2b..7be73e244a 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/kv_db_cpp_utils.cpp +++ b/fbgemm_gpu/src/split_embeddings_cache/kv_db_cpp_utils.cpp @@ -23,10 +23,10 @@ int64_t _get_bucket_id( std::optional total_num_buckets = std::nullopt) { if (hash_mode == 0) { CHECK(bucket_size.has_value()); - // hash by mod + // chunk-based hashing return id / bucket_size.value(); } else { - // hash by interleave + // interleave-based hashing CHECK(total_num_buckets.has_value()); return id % total_num_buckets.value(); } @@ -42,7 +42,7 @@ std::tuple get_bucket_sorted_indices_and_bucket_tensor( TORCH_CHECK(unordered_indices.is_contiguous()); TORCH_CHECK( hash_mode == 0 || hash_mode == 1, - "only support hash by mod and interleaved for now"); + "only support hash by chunk-based or interleaved-based hashing for now"); TORCH_CHECK( bucket_start <= bucket_end, "bucket_start:", @@ -73,11 +73,16 @@ std::tuple get_bucket_sorted_indices_and_bucket_tensor( for (int64_t i = 0; i < num_indices; ++i) { auto global_bucket_id = _get_bucket_id( indices_data_ptr[i], hash_mode, bucket_size, total_num_buckets); - CHECK(global_bucket_id >= bucket_start && global_bucket_id < bucket_end) - << "indices: " << indices_data_ptr[i] - << " bucket id: " << global_bucket_id - << " must fall into the range between:" << bucket_start << " and " - << bucket_end; + TORCH_CHECK( + global_bucket_id >= bucket_start && global_bucket_id < bucket_end, + "indices: ", + indices_data_ptr[i], + " bucket id: ", + global_bucket_id, + " must fall into the range between:", + bucket_start, + " and ", + bucket_end); if (bucket_id_to_cnt.find(global_bucket_id) == bucket_id_to_cnt.end()) { bucket_id_to_cnt[global_bucket_id] = 0; } diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h index 723ac0193e..401ad0d49c 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h @@ -86,9 +86,14 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder { int64_t start_id, int64_t end_id, int64_t id_offset, - c10::intrusive_ptr snapshot_handle) { + std::optional> + snapshot_handle) { return 
impl_->get_keys_in_range_by_snapshot( - start_id, end_id, id_offset, snapshot_handle->handle); + start_id, + end_id, + id_offset, + snapshot_handle.has_value() ? snapshot_handle.value()->handle + : nullptr); } void toggle_compaction(bool enable) { diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h index 4aebd39a36..ed6568d4c0 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h @@ -37,7 +37,8 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder { int64_t dtype, int64_t row_offset, std::optional> - snapshot_handle); + snapshot_handle = std::nullopt, + std::optional> materialized_shape = std::nullopt); at::Tensor narrow(int64_t dim, int64_t start, int64_t length); @@ -70,6 +71,7 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder { c10::intrusive_ptr snapshot_handle_; at::TensorOptions options_; std::vector shape_; + std::optional> materialized_shape_; std::vector strides_; int64_t row_offset_; }; diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp index caa83b8643..a4372a4e28 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp @@ -34,9 +34,13 @@ KVTensorWrapper::KVTensorWrapper( [[maybe_unused]] int64_t dtype, int64_t row_offset, [[maybe_unused]] std::optional< - c10::intrusive_ptr> snapshot_handle) + c10::intrusive_ptr> snapshot_handle, + [[maybe_unused]] std::optional> materialized_shape) // @lint-ignore CLANGTIDY clang-diagnostic-missing-noreturn - : db_(db->impl_), shape_(std::move(shape)), row_offset_(row_offset) { + : db_(db->impl_), + shape_(std::move(shape)), + materialized_shape_(materialized_shape), + row_offset_(row_offset) { FBEXCEPTION("Not implemented"); } diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp index 047b76963f..ffc25ccaa8 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp @@ -308,8 +308,12 @@ KVTensorWrapper::KVTensorWrapper( int64_t dtype, int64_t row_offset, std::optional> - snapshot_handle) - : db_(db->impl_), shape_(std::move(shape)), row_offset_(row_offset) { + snapshot_handle, + std::optional> materialized_shape) + : db_(db->impl_), + shape_(std::move(shape)), + materialized_shape_(materialized_shape), + row_offset_(row_offset) { CHECK_EQ(shape_.size(), 2) << "Only 2D emb tensors are supported"; options_ = at::TensorOptions() .dtype(static_cast(dtype)) @@ -380,7 +384,11 @@ at::Tensor KVTensorWrapper::get_weights_by_ids(const at::Tensor& ids) { } c10::IntArrayRef KVTensorWrapper::sizes() { - return shape_; + if (materialized_shape_.has_value()) { + return materialized_shape_.value(); + } else { + return shape_; + } } c10::IntArrayRef KVTensorWrapper::strides() { @@ -521,7 +529,8 @@ static auto kv_tensor_wrapper = int64_t, int64_t, std::optional< - c10::intrusive_ptr>>(), + c10::intrusive_ptr>, + std::optional>>(), "", {torch::arg("db"), torch::arg("shape"), @@ -529,7 +538,8 @@ static auto kv_tensor_wrapper = torch::arg("row_offset"), // snapshot must be provided for reading // not needed for writing - 
torch::arg("snapshot_handle") = std::nullopt}) + torch::arg("snapshot_handle") = std::nullopt, + torch::arg("materialized_shape") = std::nullopt}) .def( "narrow", &KVTensorWrapper::narrow, @@ -544,8 +554,7 @@ static auto kv_tensor_wrapper = .def_property( "shape", &KVTensorWrapper::sizes, - std::string( - "Returns the shape of the original tensor. Only the narrowed part is materialized.")) + std::string("Returns the shape of the original tensor.")) .def_property("strides", &KVTensorWrapper::strides); static auto dram_kv_embedding_cache_wrapper = diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h index 4750ba72cb..568f7d0e0c 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h @@ -547,7 +547,7 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { } at::Tensor returned_keys = at::empty( - total_num, at::TensorOptions().device(at::kCPU).dtype(at::kLong)); + {total_num, 1}, at::TensorOptions().device(at::kCPU).dtype(at::kLong)); auto key_ptr = returned_keys.data_ptr(); int64_t offset = 0; for (const auto& keys : keys_in_db_shards) { diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_l2_cache_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_l2_cache_test.py index 5b9d6d7ab5..bf202d4918 100644 --- a/fbgemm_gpu/test/tbe/ssd/ssd_l2_cache_test.py +++ b/fbgemm_gpu/test/tbe/ssd/ssd_l2_cache_test.py @@ -269,7 +269,6 @@ def test_rocksdb_get_discrete_ids( mixed: bool, weights_precision: SparseType, ) -> None: - weights_precision: SparseType = SparseType.FP32 emb, Es, Ds, max_D = self.generate_fbgemm_ssd_tbe( T, D, log_E, weights_precision, mixed, False, 8 ) @@ -307,7 +306,7 @@ def test_rocksdb_get_discrete_ids( start_id + offset, end_id + offset, offset, snapshot ) ids_in_range_ordered, _ = torch.sort(ids_in_range) - id_tensor_ordered, _ = torch.sort(id_tensor) + id_tensor_ordered, _ = torch.sort(id_tensor.view(-1)) assert torch.equal(ids_in_range_ordered, id_tensor_ordered) @@ -378,7 +377,8 @@ def test_get_bucket_sorted_indices( else: # test failure with self.assertRaisesRegex( - RuntimeError, "only support hash by mod and interleaved for now" + RuntimeError, + "only support hash by chunk-based or interleaved-based hashing for now", ): torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor( indices, @@ -401,7 +401,7 @@ def test_get_bucket_sorted_indices( last_bucket_id = cur_bucket_id # Calculate expected tensor output expected_bucket_tensor = torch.zeros( - bucket_end - bucket_start, 1, dtype=torch.int64 + bucket_end - bucket_start, dtype=torch.int64 ) for index in indices: self.assertTrue(hash_mode >= 0 and hash_mode <= 1) @@ -413,4 +413,4 @@ def test_get_bucket_sorted_indices( expected_bucket_tensor[bucket_id - bucket_start] += 1 # Compare actual and expected tensor outputs - self.assertTrue(torch.equal(bucket_t, expected_bucket_tensor)) + self.assertTrue(torch.equal(bucket_t.view(-1), expected_bucket_tensor)) diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py index 7f7bd06fea..b8141c4e0a 100644 --- a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py +++ b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py @@ -7,6 +7,7 @@ # pyre-strict # pyre-ignore-all-errors[3,6,56] +import math import unittest from enum import Enum @@ -17,6 +18,7 @@ import torch from fbgemm_gpu.split_embedding_configs import EmbOptimType as 
OptimType, SparseType from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( + BackendType, BoundsCheckMode, PoolingMode, ) @@ -30,6 +32,7 @@ MAX_EXAMPLES = 40 MAX_PIPELINE_EXAMPLES = 10 +KV_WORLD_SIZE = 4 default_st: Dict["str", Any] = { "T": st.integers(min_value=1, max_value=10), @@ -130,6 +133,37 @@ def test_ssd(self, indice_int64_t: bool, weights_precision: SparseType) -> None: torch.cuda.synchronize() torch.testing.assert_close(weights, output_weights) + def generate_in_bucket_indices( + self, + hash_mode: int, + bucket_id_range: Tuple[int, int], + bucket_size: int, + # max height in ref_emb, i.e. the logical id upper bound; the physical id in kv is shifted from [0, h) to [table_offset, table_offset + h) + high: int, + size: Tuple[int, int], + ) -> torch.Tensor: + """ + Generate indices within the embedding buckets; this is guaranteed by the torchrec input_dist + """ + assert hash_mode == 0, "only support hash_mode=0, aka chunk-based hashing" + + # hash mode is chunk-based hashing + # STEP 1: generate all the eligible indices in the given range + bucket_id_start = bucket_id_range[0] + bucket_id_end = bucket_id_range[1] + rank_input_range = (bucket_id_end - bucket_id_start) * bucket_size + rank_input_range = min(rank_input_range, high) + indices = torch.as_tensor( + np.random.choice(rank_input_range, replace=False, size=(rank_input_range,)), + dtype=torch.int64, + ) + indices += bucket_id_start * bucket_size + + # STEP 2: generate random indices with the desired output shape from the eligible indices above + idx = torch.randint(0, indices.numel(), size) + random_indices = indices[idx] + return random_indices + def generate_inputs_( self, B: int, @@ -139,6 +173,9 @@ def generate_inputs_( weights_precision: SparseType = SparseType.FP32, trigger_bounds_check: bool = False, mixed_B: bool = False, + is_kv_tbes: bool = False, + bucket_offsets: Optional[List[Tuple[int, int]]] = None, + bucket_sizes: Optional[List[int]] = None, ) -> Tuple[ List[torch.Tensor], List[torch.Tensor], @@ -158,12 +195,28 @@ def generate_inputs_( Bs_rank_feature, Bs = gen_mixed_B_batch_sizes(B, T) # Generate random indices and per sample weights - indices_list = [ - torch.randint( - low=0, high=Es[t] * (2 if trigger_bounds_check else 1), size=(b, L) - ).cuda() - for (b, t) in zip(Bs, feature_table_map) - ] + if is_kv_tbes: + assert len(bucket_offsets) == len(bucket_sizes) + assert len(bucket_offsets) <= len(feature_table_map) + indices_list = [ + self.generate_in_bucket_indices( + 0, + # pyre-ignore + bucket_offsets[t], + # pyre-ignore + bucket_sizes[t], + Es[t] * (2 if trigger_bounds_check else 1), + size=(b, L), + ).cuda() + for (b, t) in zip(Bs, feature_table_map) + ] + else: + indices_list = [ + torch.randint( + low=0, high=Es[t] * (2 if trigger_bounds_check else 1), size=(b, L) + ).cuda() + for (b, t) in zip(Bs, feature_table_map) + ] per_sample_weights_list = [torch.randn(size=(b, L)).cuda() for b in Bs] # Concat inputs for SSD TBE @@ -197,6 +250,178 @@ def generate_inputs_( batch_size_per_feature_per_rank, ) + def generate_kv_tbes( + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weighted: bool, + lr: float = 0.01, # from SSDTableBatchedEmbeddingBags + eps: float = 1.0e-8, # from SSDTableBatchedEmbeddingBags + ssd_shards: int = 1, # from SSDTableBatchedEmbeddingBags + optimizer: OptimType = OptimType.EXACT_ROWWISE_ADAGRAD, + cache_set_scale: float = 1.0, + # pyre-fixme[9]: pooling_mode has type `bool`; used as `PoolingMode`.
+ pooling_mode: bool = PoolingMode.SUM, + weights_precision: SparseType = SparseType.FP32, + output_dtype: SparseType = SparseType.FP32, + stochastic_rounding: bool = True, + share_table: bool = False, + prefetch_pipeline: bool = False, + zero_collision_tbe: bool = False, + backend_type: BackendType = BackendType.SSD, + num_buckets: int = 5, + ) -> Tuple[ + SSDTableBatchedEmbeddingBags, + List[torch.nn.EmbeddingBag], + List[int], + List[Tuple[int, int]], + List[int], + ]: + """ + Generate embedding modules (i.e., SSDTableBatchedEmbeddingBags and + torch.nn.EmbeddingBags) + + The idea in this UT: originally we have a ref_emb using EmbeddingBag/Embedding with the same size as emb, + to simulate the lookup result with different pooling, weighted, etc., and we do the lookup on both + ref_emb and emb and compare the results. + + However, when we test the kv zch embedding lookup, we use a virtual space which cannot be preallocated + in ref_emb using EmbeddingBag/Embedding. + + The little trick we do here is that we still pre-allocate ref_emb with the given table size, + but when we create the SSD TBE, we pass in the virtual table size. + + For example, if the given table size is [100, 256], we have ref_emb with [100, 256] and SSD TBE with [2^25, 256]; + the input ids will always be in the range [0, 100) + """ + import tempfile + + torch.manual_seed(42) + E = int(10**log_E) + virtual_E = int( + 2**18 + ) # relatively large for now given the optimizer is still pre-allocated + D = D * 4 + Ds = [D] * T + Es = [E] * T + + if pooling_mode == PoolingMode.SUM: + mode = "sum" + do_pooling = True + elif pooling_mode == PoolingMode.MEAN: + mode = "mean" + do_pooling = True + elif pooling_mode == PoolingMode.NONE: + mode = "sum" + do_pooling = False + else: + # This proves that we have exhaustively checked all PoolingModes + raise RuntimeError("Unknown PoolingMode!") + + # Generate torch EmbeddingBag + # in kv tbes, we still maintain a small emb in EmbeddingBag or Embedding as a reference for the expected outcome, + # but the virtual space passed into TBE will be super large, e.g.
2^50 + # NOTE we will use a relative large virtual size for now, given that optimizer is still pre-allocated + if do_pooling: + emb_ref = [ + torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True).cuda() + for (E, D) in zip(Es, Ds) + ] + else: + emb_ref = [ + torch.nn.Embedding(E, D, sparse=True).cuda() for (E, D) in zip(Es, Ds) + ] + + # Cast type + if weights_precision == SparseType.FP16: + emb_ref = [emb.half() for emb in emb_ref] + + # Init weights + [emb.weight.data.uniform_(-2.0, 2.0) for emb in emb_ref] + + # Construct feature_table_map + feature_table_map = list(range(T)) + table_to_replicate = -1 + if share_table: + # autograd with shared embedding only works for exact + table_to_replicate = T // 2 + # pyre-ignore + feature_table_map.insert(table_to_replicate, table_to_replicate) + emb_ref.insert(table_to_replicate, emb_ref[table_to_replicate]) + + cache_sets = max(int(max(T * B * L, 1) * cache_set_scale), 1) + + # Generate TBE SSD + bucket_sizes = [] + bucket_offsets = [] + for _ in Es: + bucket_sizes.append(math.ceil(virtual_E / num_buckets)) + bucket_start = ( + 0 # since ref_emb is dense format, we need to start from 0th bucket + ) + bucket_end = min(math.ceil(num_buckets / KV_WORLD_SIZE), num_buckets) + bucket_offsets.append((bucket_start, bucket_end)) + emb = SSDTableBatchedEmbeddingBags( + embedding_specs=[(virtual_E, D) for D in Ds], + feature_table_map=feature_table_map, + ssd_storage_directory=tempfile.mkdtemp(), + cache_sets=cache_sets, + ssd_uniform_init_lower=-0.1, + ssd_uniform_init_upper=0.1, + learning_rate=lr, + eps=eps, + ssd_rocksdb_shards=ssd_shards, + optimizer=optimizer, + pooling_mode=pooling_mode, + weights_precision=weights_precision, + output_dtype=output_dtype, + stochastic_rounding=stochastic_rounding, + prefetch_pipeline=prefetch_pipeline, + bounds_check_mode=BoundsCheckMode.WARNING, + l2_cache_size=8, + zero_collision_tbe=zero_collision_tbe, + backend_type=backend_type, + bucket_offsets=bucket_offsets, + bucket_sizes=bucket_sizes, + ).cuda() + + self.assertTrue(emb.ssd_db.is_auto_compaction_enabled()) + + # By doing the check for ssd_db being None below, we also access the getter property of ssd_db, which will + # force the synchronization of lazy_init_thread, and then reset it to None. + if emb.ssd_db is not None: + self.assertIsNone(emb.lazy_init_thread) + + # A list to keep the CPU tensor alive until `set` (called inside + # `set_cuda`) is complete. 
Note that `set_cuda` is non-blocking + # asynchronous + emb_ref_cpu = [] + + # Initialize TBE SSD weights + for f, t in self.get_physical_table_arg_indices_(emb.feature_table_map): + emb_ref_ = emb_ref[f].weight.clone().detach().cpu() + emb.ssd_db.set_cuda( + torch.arange(t * virtual_E, t * virtual_E + E).to(torch.int64), + emb_ref_, + torch.as_tensor([E]), + t, + ) + emb_ref_cpu.append(emb_ref_) + + # Ensure that `set` (invoked by `set_cuda`) is done + torch.cuda.synchronize() + + # Convert back to float (to make sure that accumulation is done + # in FP32 -- like TBE) + if weights_precision == SparseType.FP16: + emb_ref = [emb.float() for emb in emb_ref] + + # pyre-fixme[7] + return emb, emb_ref, Es, bucket_offsets, bucket_sizes + def generate_ssd_tbes( self, T: int, @@ -659,7 +884,7 @@ def test_ssd_backward_adagrad( ) # Compare optimizer states - split_optimizer_states = [s for (s,) in emb.debug_split_optimizer_states()] + split_optimizer_states = [s for (s, _, _) in emb.debug_split_optimizer_states()] for f, t in self.get_physical_table_arg_indices_(emb.feature_table_map): # pyre-fixme[16]: Optional type has no attribute `float`. ref_optimizer_state = emb_ref[f].weight.grad.float().to_dense().pow(2) @@ -801,11 +1026,11 @@ def test_ssd_emb_state_dict( else 1.0e-2 ) - split_optimizer_states = [s for (s,) in emb.debug_split_optimizer_states()] + split_optimizer_states = [s for (s, _, _) in emb.debug_split_optimizer_states()] emb.flush() # Compare emb state dict with expected values from nn.EmbeddingBag - emb_state_dict = emb.split_embedding_weights(no_snapshot=False) + emb_state_dict, _, _ = emb.split_embedding_weights(no_snapshot=False) for feature_index, table_index in self.get_physical_table_arg_indices_( emb.feature_table_map ): @@ -886,7 +1111,7 @@ def execute_ssd_cache_pipeline_( # noqa C901 ) optimizer_states_ref = [ - s.clone().float() for (s,) in emb.debug_split_optimizer_states() + s.clone().float() for (s, _, _) in emb.debug_split_optimizer_states() ] Es = [emb.embedding_specs[t][0] for t in range(T)] @@ -1032,7 +1257,9 @@ def _prefetch(b_it: int) -> int: emb.flush() # Compare optimizer states - split_optimizer_states = [s for (s,) in emb.debug_split_optimizer_states()] + split_optimizer_states = [ + s for (s, _, _) in emb.debug_split_optimizer_states() + ] for f, t in self.get_physical_table_arg_indices_(emb.feature_table_map): optim_state_r = optimizer_states_ref[t] optim_state_t = split_optimizer_states[t] @@ -1220,3 +1447,299 @@ def test_ssd_cache_pipeline_between_fwd_bwd(self, **kwargs: Any): flush_location=None, **kwargs, ) + + @given( + **default_st, + num_buckets=st.integers(min_value=10, max_value=15), + ) + @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + def test_kv_db_forward( + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weighted: bool, + cache_set_scale: float, + pooling_mode: PoolingMode, + weights_precision: SparseType, + output_dtype: SparseType, + share_table: bool, + trigger_bounds_check: bool, + mixed_B: bool, + num_buckets: int, + ) -> None: + trigger_bounds_check = False # don't stimulate boundary check cases + assume(not weighted or pooling_mode == PoolingMode.SUM) + assume(not mixed_B or pooling_mode != PoolingMode.NONE) + + # Generate embedding modules + ( + emb, + emb_ref, + Es, + bucket_offsets, + bucket_sizes, + ) = self.generate_kv_tbes( + T, + D, + B, + log_E, + L, + weighted, + cache_set_scale=cache_set_scale, + pooling_mode=pooling_mode, + weights_precision=weights_precision, + 
output_dtype=output_dtype, + share_table=share_table, + zero_collision_tbe=True, + num_buckets=num_buckets, + ) + + # Generate inputs + ( + indices_list, + per_sample_weights_list, + indices, + offsets, + per_sample_weights, + batch_size_per_feature_per_rank, + ) = self.generate_inputs_( + B, + L, + Es, + emb.feature_table_map, + weights_precision=weights_precision, + trigger_bounds_check=trigger_bounds_check, + mixed_B=mixed_B, + bucket_offsets=bucket_offsets, + bucket_sizes=bucket_sizes, + is_kv_tbes=True, + ) + + # Execute forward + self.execute_ssd_forward_( + emb, + emb_ref, + indices_list, + per_sample_weights_list, + indices, + offsets, + per_sample_weights, + B, + L, + weighted, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + ) + + @given( + **default_st, + num_buckets=st.integers(min_value=10, max_value=15), + ) + @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + def test_kv_emb_state_dict( + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weighted: bool, + cache_set_scale: float, + pooling_mode: PoolingMode, + weights_precision: SparseType, + output_dtype: SparseType, + share_table: bool, + trigger_bounds_check: bool, + mixed_B: bool, + num_buckets: int, + ) -> None: + # Constants + lr = 0.5 + eps = 0.2 + ssd_shards = 2 + + trigger_bounds_check = False # don't stimulate boundary check cases + assume(not weighted or pooling_mode == PoolingMode.SUM) + assume(not mixed_B or pooling_mode != PoolingMode.NONE) + + # Generate embedding modules and inputs + ( + emb, + emb_ref, + Es, + bucket_offsets, + bucket_sizes, + ) = self.generate_kv_tbes( + T, + D, + B, + log_E, + L, + weighted, + lr=lr, + eps=eps, + ssd_shards=ssd_shards, + cache_set_scale=cache_set_scale, + pooling_mode=pooling_mode, + weights_precision=weights_precision, + output_dtype=output_dtype, + share_table=share_table, + zero_collision_tbe=True, + num_buckets=num_buckets, + ) + + # Generate inputs + ( + indices_list, + per_sample_weights_list, + indices, + offsets, + per_sample_weights, + batch_size_per_feature_per_rank, + ) = self.generate_inputs_( + B, + L, + Es, + emb.feature_table_map, + weights_precision=weights_precision, + trigger_bounds_check=trigger_bounds_check, + mixed_B=mixed_B, + bucket_offsets=bucket_offsets, + bucket_sizes=bucket_sizes, + is_kv_tbes=True, + ) + + # Execute forward + output_ref_list, output = self.execute_ssd_forward_( + emb, + emb_ref, + indices_list, + per_sample_weights_list, + indices, + offsets, + per_sample_weights, + B, + L, + weighted, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + ) + + # Generate output gradient + output_grad_list = [torch.randn_like(out) for out in output_ref_list] + + # Execute torch EmbeddingBag backward + [out.backward(grad) for (out, grad) in zip(output_ref_list, output_grad_list)] + if batch_size_per_feature_per_rank is not None: + grad_test = self.concat_ref_tensors_vbe( + output_grad_list, batch_size_per_feature_per_rank + ) + else: + grad_test = self.concat_ref_tensors( + output_grad_list, + pooling_mode != PoolingMode.NONE, # do_pooling + B, + D * 4, + ) + + # Execute TBE SSD backward + output.backward(grad_test) + + tolerance = ( + 1.0e-4 + if weights_precision == SparseType.FP32 and output_dtype == SparseType.FP32 + else 1.0e-2 + ) + + emb.flush() + + split_optimizer_states = [] + table_input_id_range = [] + for s, input_id_start, input_id_end in emb.debug_split_optimizer_states(): + split_optimizer_states.append(s) + # the ref_emb might contains ids out of the 
bucket input range + table_input_id_range.append((input_id_start, input_id_end)) + # since we use ref_emb in dense format, the rows start from id 0 + self.assertEqual(input_id_start, 0) + + # Compare optimizer states + for f, t in self.get_physical_table_arg_indices_(emb.feature_table_map): + # pyre-fixme[16]: Optional type has no attribute `float`. + ref_optimizer_state = emb_ref[f].weight.grad.float().to_dense().pow(2) + torch.testing.assert_close( + split_optimizer_states[t].float(), + ref_optimizer_state.mean(dim=1)[ + table_input_id_range[t][0] : min( + table_input_id_range[t][1], emb_ref[f].weight.size(0) + ) + ], + atol=tolerance, + rtol=tolerance, + ) + + # Compare emb state dict with expected values from nn.EmbeddingBag + emb_state_dict_list, bucket_asc_ids_list, num_active_id_per_bucket_list = ( + emb.split_embedding_weights(no_snapshot=False, should_flush=True) + ) + for feature_index, table_index in self.get_physical_table_arg_indices_( + emb.feature_table_map + ): + ################################################################# + ## validate bucket_asc_ids_list and num_active_id_per_bucket_list + ################################################################# + bucket_asc_id = bucket_asc_ids_list[table_index] + num_active_id_per_bucket = num_active_id_per_bucket_list[table_index] + + bucket_id_start = bucket_offsets[table_index][0] + bucket_id_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + num_active_id_per_bucket.view(-1) + ) + for bucket_idx, id_count in enumerate(num_active_id_per_bucket): + bucket_id = bucket_idx + bucket_id_start + active_id_cnt = 0 + for idx in range( + bucket_id_offsets[bucket_idx], + bucket_id_offsets[bucket_idx + 1], + ): + # for chunk-based hashing + self.assertEqual( + bucket_id, bucket_asc_id[idx] // bucket_sizes[table_index] + ) + active_id_cnt += 1 + self.assertEqual(active_id_cnt, id_count) + + ###################### + ## validate embeddings + ###################### + id_range_start = table_input_id_range[table_index][ + 0 + ] # should be 0 because ref_emb is preallocated + id_range_end = min(table_input_id_range[table_index][1], Es[table_index]) + emb_r = emb_ref[feature_index] + self.assertLess(table_index, len(emb_state_dict_list)) + new_ref_weight = torch.addcdiv( + emb_r.weight.float()[id_range_start:id_range_end,], + value=-lr, + tensor1=emb_r.weight.grad.float().to_dense()[id_range_start:id_range_end,], # pyre-ignore[16] + tensor2=split_optimizer_states[table_index] + .float() + .sqrt_() + .add_(eps) + .view( + id_range_end - id_range_start, + 1, + ), + ).cpu() + + emb_w = ( + emb_state_dict_list[table_index] + .narrow(0, 0, id_range_end - id_range_start) + .float() + ) + torch.testing.assert_close( + emb_w, + new_ref_weight, + atol=tolerance, + rtol=tolerance, + )
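For reviewers who want to try the new constructor arguments, here is a minimal, illustrative construction sketch; the table size, dim, cache_sets, bucket count, and temp directory below are placeholder values chosen for the example, not taken from this patch:

import tempfile

from fbgemm_gpu.split_table_batched_embeddings_ops_common import BackendType
from fbgemm_gpu.tbe.ssd.training import SSDTableBatchedEmbeddingBags

# One logical table with a virtual global size of 2**18 rows and dim 128,
# split into 4 buckets; this rank owns global buckets [0, 2).
virtual_E, D, num_buckets = 2**18, 128, 4
emb = SSDTableBatchedEmbeddingBags(
    embedding_specs=[(virtual_E, D)],
    feature_table_map=[0],
    cache_sets=32,  # placeholder cache size
    ssd_storage_directory=tempfile.mkdtemp(),
    zero_collision_tbe=True,
    backend_type=BackendType.SSD,
    bucket_offsets=[(0, 2)],  # [start, end) global bucket ids owned by this rank
    bucket_sizes=[virtual_E // num_buckets],  # input-id space covered by each bucket
).cuda()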
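And a sketch of how a checkpoint-style consumer might unpack the new three-part return of split_embedding_weights; `emb` is assumed to be a KV-ZCH instance like the one in the previous sketch:

# split_embedding_weights now returns (PMT list, bucket-sorted ids, per-bucket id counts);
# the two id-related lists are None unless zero_collision_tbe is enabled.
pmt_list, bucket_sorted_ids, per_bucket_cnts = emb.split_embedding_weights(
    no_snapshot=False, should_flush=True
)
for t, pmt in enumerate(pmt_list):
    if bucket_sorted_ids is None:
        continue  # non-KV-ZCH module: only the PMTs are meaningful
    ids = bucket_sorted_ids[t].view(-1)  # ids sorted by ascending bucket id
    cnts = per_bucket_cnts[t].view(-1)  # one active-id count per owned bucket
    # the per-bucket counts must account for every returned id
    assert ids.numel() == int(cnts.sum().item())
    print(t, pmt.is_virtual, ids.numel())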
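Finally, a small Python paraphrase of the bucket-id computation in kv_db_cpp_utils.cpp, just to spell out the two hash modes side by side; the function name below is ours, not part of the patch:

def bucket_id_for(id_: int, hash_mode: int, bucket_size: int, total_num_buckets: int) -> int:
    # Mirrors _get_bucket_id: mode 0 = chunk-based hashing (contiguous id ranges per bucket),
    # mode 1 = interleave-based hashing (ids round-robin across buckets).
    if hash_mode == 0:
        return id_ // bucket_size
    return id_ % total_num_buckets

# e.g. with bucket_size=100 and 4 buckets, id 205 lands in bucket 2 under chunk-based
# hashing (205 // 100) but in bucket 1 under interleave-based hashing (205 % 4).
assert bucket_id_for(205, 0, bucket_size=100, total_num_buckets=4) == 2
assert bucket_id_for(205, 1, bucket_size=100, total_num_buckets=4) == 1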