diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py
index ad1e81dfbc..b301be3ae4 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py
@@ -57,6 +57,12 @@ class KVZCHParams(NamedTuple):
     # the value indicates corresponding input space for each bucket id, e.g. 2^50 / total_num_buckets
     bucket_sizes: List[int] = []
 
+    def validate(self) -> None:
+        assert len(self.bucket_offsets) == len(self.bucket_sizes), (
+            "bucket_offsets and bucket_sizes must have the same length, "
+            f"actual {self.bucket_offsets} vs {self.bucket_sizes}"
+        )
+
 
 class BackendType(enum.IntEnum):
     SSD = 0
diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
index 23aba69ff5..64e2ef4569 100644
--- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
+++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -29,9 +29,11 @@
 )
 from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
+    BackendType,
     BoundsCheckMode,
     CacheAlgorithm,
     EmbeddingLocation,
+    KVZCHParams,
     PoolingMode,
     SplitState,
 )
@@ -152,6 +154,8 @@ def __init__(
         # number of rows will be decided by bulk_init_chunk_size / size_of_each_row
         bulk_init_chunk_size: int = 0,
         lazy_bulk_init_enabled: bool = False,
+        backend_type: BackendType = BackendType.SSD,
+        kv_zch_params: Optional[KVZCHParams] = None,
     ) -> None:
         super(SSDTableBatchedEmbeddingBags, self).__init__()
 
@@ -426,6 +430,12 @@ def __init__(
         )
         # logging.info("DEBUG: weights_precision {}".format(weights_precision))
 
+        # zero collision TBE configurations
+        self.kv_zch_params = kv_zch_params
+        self.backend_type = backend_type
+        if self.kv_zch_params:
+            self.kv_zch_params.validate()
+
         # create tbe unique id using rank index | local tbe idx
         if tbe_unique_id == -1:
             SSDTableBatchedEmbeddingBags._local_instance_index += 1
@@ -442,7 +452,7 @@ def __init__(
             tbe_unique_id = SSDTableBatchedEmbeddingBags._local_instance_index
         self.tbe_unique_id = tbe_unique_id
         logging.info(f"tbe_unique_id: {tbe_unique_id}")
-        if not ps_hosts:
+        if self.backend_type == BackendType.SSD:
             logging.info(
                 f"Logging SSD offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB, enable_async_update:{enable_async_update}"
                 f"passed_in_path={ssd_directory}, num_shards={ssd_rocksdb_shards},num_threads={ssd_rocksdb_shards},"
@@ -485,11 +495,9 @@ def __init__(
                 self._lazy_initialize_ssd_tbe()
             else:
                 self._insert_all_kv()
-        else:
-            # pyre-fixme[4]: Attribute must be annotated.
-            # pyre-ignore[16]
+        elif self.backend_type == BackendType.PS:
             self._ssd_db = torch.classes.fbgemm.EmbeddingParameterServerWrapper(
-                [host[0] for host in ps_hosts],
+                [host[0] for host in ps_hosts],  # pyre-ignore
                 [host[1] for host in ps_hosts],
                 tbe_unique_id,
                 (
@@ -502,6 +510,9 @@ def __init__(
                 l2_cache_size,
                 self.max_D,
             )
+        else:
+            raise AssertionError(f"Invalid backend type {self.backend_type}")
+
         # pyre-fixme[20]: Argument `self` expected.
         (low_priority, high_priority) = torch.cuda.Stream.priority_range()
         # GPU stream for SSD cache eviction
@@ -1718,23 +1729,53 @@ def forward(
         )
 
     @torch.jit.ignore
-    def debug_split_optimizer_states(self) -> List[Tuple[torch.Tensor]]:
+    def debug_split_optimizer_states(self) -> List[Tuple[torch.Tensor, int, int]]:
         """
-        Returns a list of states, split by table
+        Returns a list of (optimizer state, table_input_id_start, table_input_id_end) tuples, split by table
         Testing only
         """
         (rows, _) = zip(*self.embedding_specs)
         rows_cumsum = [0] + list(itertools.accumulate(rows))
-
-        return [
-            (
-                self.momentum1_dev.detach()[rows_cumsum[t] : rows_cumsum[t + 1]].view(
-                    row
-                ),
-            )
-            for t, row in enumerate(rows)
-        ]
+        if self.kv_zch_params:
+            opt_list = []
+            table_offset = 0
+            for t, row in enumerate(rows):
+                # pyre-ignore
+                bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[t]
+                # pyre-ignore
+                bucket_size = self.kv_zch_params.bucket_sizes[t]
+                table_input_id_start = bucket_id_start * bucket_size + table_offset
+                table_input_id_end = bucket_id_end * bucket_size + table_offset
+
+                # TODO: this is a hack for the preallocated optimizer; update this part once we have optimizer offloading
+                unlinearized_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot(
+                    table_input_id_start,
+                    table_input_id_end,
+                    0,  # no need for a table offset, as the optimizer is preallocated using the table offset
+                    None,
+                )
+                sorted_offsets, _ = torch.sort(unlinearized_id_tensor.view(-1))
+                opt_list.append(
+                    (
+                        self.momentum1_dev.detach()[sorted_offsets],
+                        table_input_id_start - table_offset,
+                        table_input_id_end - table_offset,
+                    )
+                )
+                table_offset += row
+            return opt_list
+        else:
+            return [
+                (
+                    self.momentum1_dev.detach()[
+                        rows_cumsum[t] : rows_cumsum[t + 1]
+                    ].view(row),
+                    -1,
+                    -1,
+                )
+                for t, row in enumerate(rows)
+            ]
 
     @torch.jit.export
     def debug_split_embedding_weights(self) -> List[torch.Tensor]:
@@ -1786,7 +1827,11 @@ def split_embedding_weights(
         self,
         no_snapshot: bool = True,
         should_flush: bool = False,
-    ) -> List[PartiallyMaterializedTensor]:
+    ) -> Tuple[  # TODO: make this a NamedTuple for readability
+        List[PartiallyMaterializedTensor],
+        Optional[List[torch.Tensor]],
+        Optional[List[torch.Tensor]],
+    ]:
         """
         This method is intended to be used by the checkpointing engine
         only.
@@ -1800,35 +1845,86 @@ def split_embedding_weights(
                 operation, only set to True when necessary.
 
         Returns:
-            a list of partially materialized tensors, each representing a table
+            a tuple of 3 lists, where each element corresponds to a logical table:
+            1st list: partially materialized tensors, each representing a table
+            2nd list: input ids sorted in bucket-id ascending order
+            3rd list: active id count per bucket id; tensor size is [bucket_id_end - bucket_id_start],
+                where for the i-th element, i + bucket_id_start = the global bucket id
         """
         # Force device synchronize for now
         torch.cuda.synchronize()
 
-        # Create a snapshot
-        if no_snapshot:
-            snapshot_handle = None
-        else:
-            if should_flush:
-                # Flush L1 and L2 caches
-                self.flush()
-            snapshot_handle = self.ssd_db.create_snapshot()
+        snapshot_handle = None
+        if self.backend_type == BackendType.SSD:
+            # Create a rocksdb snapshot
+            if not no_snapshot:
+                if should_flush:
+                    # Flush L1 and L2 caches
+                    self.flush()
+                snapshot_handle = self.ssd_db.create_snapshot()
+        elif self.backend_type == BackendType.DRAM:
+            self.flush()
 
         dtype = self.weights_precision.as_dtype()
-        splits = []
+        pmt_splits = []
+        bucket_sorted_id_splits = [] if self.kv_zch_params else None
+        active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None
+
+        table_offset = 0
+        for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
+            bucket_ascending_id_tensor = None
+            bucket_t = None
+            if self.kv_zch_params:
+                bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
+                # pyre-ignore
+                bucket_size = self.kv_zch_params.bucket_sizes[i]
+
+                # linearize with the table offset
+                table_input_id_start = bucket_id_start * bucket_size + table_offset
+                table_input_id_end = bucket_id_end * bucket_size + table_offset
+                # 1. get all keys from the backend for one table
+                unordered_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot(
+                    table_input_id_start,
+                    table_input_id_end,
+                    table_offset,
+                    snapshot_handle,
+                )
+                # 2. sort the keys in bucket-ascending order
+                bucket_ascending_id_tensor, bucket_t = (
+                    torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor(
+                        unordered_id_tensor,
+                        0,  # id-to-bucket hashing mode, 0 for chunk-based hashing, 1 for interleave-based hashing
+                        bucket_id_start,
+                        bucket_id_end,
+                        bucket_size,
+                    )
+                )
+                # pyre-ignore
+                bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
+                active_id_cnt_per_bucket_split.append(bucket_t)
 
-        row_offset = 0
-        for emb_height, emb_dim in self.embedding_specs:
             tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
-                shape=[emb_height, emb_dim],
+                shape=[
+                    (
+                        bucket_ascending_id_tensor.size(0)
+                        if bucket_ascending_id_tensor is not None
+                        else emb_height
+                    ),
+                    emb_dim,
+                ],
                 dtype=dtype,
-                row_offset=row_offset,
+                row_offset=table_offset,
                 snapshot_handle=snapshot_handle,
             )
+            # TODO: add if-else support in the future for DRAM integration.
             tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
-            row_offset += emb_height
-            splits.append(PartiallyMaterializedTensor(tensor_wrapper))
-        return splits
+            table_offset += emb_height
+            pmt_splits.append(
+                PartiallyMaterializedTensor(
+                    tensor_wrapper,
+                    True if self.kv_zch_params else False,
+                )
+            )
+        return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split)
 
     @torch.jit.export
     def set_learning_rate(self, lr: float) -> None:
diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
index 7f7bd06fea..f094678873 100644
--- a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
+++ b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
@@ -659,7 +659,7 @@ def test_ssd_backward_adagrad(
         )
 
         # Compare optimizer states
-        split_optimizer_states = [s for (s,) in emb.debug_split_optimizer_states()]
+        split_optimizer_states = [s for (s, _, _) in emb.debug_split_optimizer_states()]
         for f, t in self.get_physical_table_arg_indices_(emb.feature_table_map):
             # pyre-fixme[16]: Optional type has no attribute `float`.
             ref_optimizer_state = emb_ref[f].weight.grad.float().to_dense().pow(2)
@@ -801,11 +801,11 @@ def test_ssd_emb_state_dict(
             else 1.0e-2
         )
 
-        split_optimizer_states = [s for (s,) in emb.debug_split_optimizer_states()]
+        split_optimizer_states = [s for (s, _, _) in emb.debug_split_optimizer_states()]
         emb.flush()
 
         # Compare emb state dict with expected values from nn.EmbeddingBag
-        emb_state_dict = emb.split_embedding_weights(no_snapshot=False)
+        emb_state_dict, _, _ = emb.split_embedding_weights(no_snapshot=False)
         for feature_index, table_index in self.get_physical_table_arg_indices_(
             emb.feature_table_map
         ):
@@ -886,7 +886,7 @@ def execute_ssd_cache_pipeline_(  # noqa C901
         )
 
         optimizer_states_ref = [
-            s.clone().float() for (s,) in emb.debug_split_optimizer_states()
+            s.clone().float() for (s, _, _) in emb.debug_split_optimizer_states()
         ]
 
         Es = [emb.embedding_specs[t][0] for t in range(T)]
@@ -1032,7 +1032,9 @@ def _prefetch(b_it: int) -> int:
         emb.flush()
 
         # Compare optimizer states
-        split_optimizer_states = [s for (s,) in emb.debug_split_optimizer_states()]
+        split_optimizer_states = [
+            s for (s, _, _) in emb.debug_split_optimizer_states()
+        ]
         for f, t in self.get_physical_table_arg_indices_(emb.feature_table_map):
             optim_state_r = optimizer_states_ref[t]
             optim_state_t = split_optimizer_states[t]
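
Note on the id linearization used in both debug_split_optimizer_states and split_embedding_weights above: each table owns a contiguous range of bucket ids [bucket_id_start, bucket_id_end), each bucket covers bucket_size input ids, and per-table ranges are shifted by a running table_offset (the cumulative emb_height) so key ranges of different tables never overlap in the backend. A minimal standalone sketch of that arithmetic; the bucket parameters and table sizes are made up for illustration and nothing here touches the real backend:

```python
# Illustration only: bucket parameters and table sizes below are invented.
from typing import List, Tuple

bucket_offsets: List[Tuple[int, int]] = [(0, 4), (4, 8)]  # per-table (bucket_id_start, bucket_id_end)
bucket_sizes: List[int] = [100, 100]  # input-id space covered by each bucket
rows: List[int] = [1000, 2000]  # emb_height per table, as in embedding_specs

table_offset = 0
for t, row in enumerate(rows):
    bucket_id_start, bucket_id_end = bucket_offsets[t]
    bucket_size = bucket_sizes[t]
    # linearized key range queried from the backend for table t
    table_input_id_start = bucket_id_start * bucket_size + table_offset
    table_input_id_end = bucket_id_end * bucket_size + table_offset
    print(f"table {t}: backend keys in [{table_input_id_start}, {table_input_id_end})")
    table_offset += row  # shift the next table past this one's rows
```

Running this prints [0, 400) for table 0 and [1400, 1800) for table 1, which is why subtracting table_offset back out (as debug_split_optimizer_states does) recovers per-table ids.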
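And a hedged sketch of how a caller adapts to the new return shapes. Everything here is an assumption for illustration: `emb` stands for an SSDTableBatchedEmbeddingBags constructed elsewhere with the backend_type/kv_zch_params arguments added in this diff, the KVZCHParams values are invented, and keyword construction from just these two fields assumes the NamedTuple's other fields (if any) have defaults:

```python
# Hypothetical consumer of the APIs changed in this diff.
from fbgemm_gpu.split_table_batched_embeddings_ops_common import KVZCHParams

# validate() asserts the two per-table lists line up (see the first hunk).
params = KVZCHParams(bucket_offsets=[(0, 4)], bucket_sizes=[100])
params.validate()


def checkpoint_tables(emb) -> None:
    # split_embedding_weights now returns a 3-tuple instead of a bare list;
    # the 2nd and 3rd elements are None unless kv_zch_params is set.
    pmts, bucket_sorted_ids, per_bucket_counts = emb.split_embedding_weights(
        no_snapshot=False, should_flush=True
    )
    for t, pmt in enumerate(pmts):
        if bucket_sorted_ids is not None:
            ids = bucket_sorted_ids[t]  # input ids, sorted by bucket
            counts = per_bucket_counts[t]  # active-id count per bucket
        # ... persist pmt (a PartiallyMaterializedTensor) plus the metadata

    # debug_split_optimizer_states entries are now
    # (optimizer_state, table_input_id_start, table_input_id_end),
    # which is why the tests below unpack (s, _, _).
    opt_states = [s for (s, _, _) in emb.debug_split_optimizer_states()]
```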