From e4eb029d9fcd8874fc53e8c8eefb82dfe6ce2ed3 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Tue, 17 Dec 2024 11:34:55 +0000 Subject: [PATCH] 2024-12-17 nightly release (e6e4f6ce33bcfc933d5045d6ad1535db67fee99a) --- .../distributed/benchmark/benchmark_train.py | 1 - torchrec/distributed/comm_ops.py | 109 ------------------ .../distributed/embedding_dim_bucketer.py | 5 +- torchrec/distributed/embedding_lookup.py | 33 ------ torchrec/distributed/embeddingbag.py | 4 +- torchrec/distributed/fused_params.py | 12 +- torchrec/distributed/shard.py | 1 - torchrec/distributed/sharding/cw_sharding.py | 5 +- .../distributed/sharding/grid_sharding.py | 5 +- .../sharding/rw_sequence_sharding.py | 1 - torchrec/distributed/sharding/rw_sharding.py | 4 +- .../distributed/sharding/twrw_sharding.py | 4 +- torchrec/distributed/tensor_sharding.py | 2 +- .../distributed/test_utils/test_sharding.py | 2 - .../distributed/tests/test_2d_sharding.py | 2 +- torchrec/distributed/tests/test_comm.py | 78 ------------- .../distributed/tests/test_infer_shardings.py | 2 - .../tests/test_quant_model_parallel.py | 2 +- ...est_sequence_model_parallel_single_rank.py | 11 +- .../train_pipeline/train_pipelines.py | 3 - torchrec/distributed/types.py | 6 + torchrec/inference/client.py | 1 - torchrec/inference/tests/test_inference.py | 1 - torchrec/ir/types.py | 2 +- 24 files changed, 25 insertions(+), 271 deletions(-) diff --git a/torchrec/distributed/benchmark/benchmark_train.py b/torchrec/distributed/benchmark/benchmark_train.py index d8cf35d00..15ea780f2 100644 --- a/torchrec/distributed/benchmark/benchmark_train.py +++ b/torchrec/distributed/benchmark/benchmark_train.py @@ -10,7 +10,6 @@ #!/usr/bin/env python3 import argparse -import copy import logging import os import time diff --git a/torchrec/distributed/comm_ops.py b/torchrec/distributed/comm_ops.py index cffcdecda..856f50d10 100644 --- a/torchrec/distributed/comm_ops.py +++ b/torchrec/distributed/comm_ops.py @@ -739,115 +739,6 @@ def all2all_sequence_sync( return sharded_output_embeddings.view(-1, D) -def alltoallv( - inputs: List[Tensor], - out_split: Optional[List[int]] = None, - per_rank_split_lengths: Optional[List[int]] = None, - group: Optional[dist.ProcessGroup] = None, - codecs: Optional[QuantizedCommCodecs] = None, -) -> Awaitable[List[Tensor]]: - """ - Performs `alltoallv` operation for a list of input embeddings. Each process scatters - the list to all processes in the group. - - Args: - inputs (List[Tensor]): list of tensors to scatter, one per rank. The tensors in - the list usually have different lengths. - out_split (Optional[List[int]]): output split sizes (or dim_sum_per_rank), if - not specified, we will use `per_rank_split_lengths` to construct a output - split with the assumption that all the embs have the same dimension. - per_rank_split_lengths (Optional[List[int]]): split lengths per rank. If not - specified, the `out_split` must be specified. - group (Optional[dist.ProcessGroup]): the process group to work on. If None, the - default process group will be used. - codecs (Optional[QuantizedCommCodecs]): quantized communication codecs. - - Returns: - Awaitable[List[Tensor]]: async work handle (`Awaitable`), which can be `wait()` later to get the resulting list of tensors. - - .. warning:: - `alltoallv` is experimental and subject to change. 
- """ - - if group is None: - group = dist.distributed_c10d._get_default_group() - - world_size: int = group.size() - my_rank: int = group.rank() - - B_global = inputs[0].size(0) - - D_local_list = [] - for e in inputs: - D_local_list.append(e.size()[1]) - - B_local, B_local_list = _get_split_lengths_by_len(world_size, my_rank, B_global) - - if out_split is not None: - dims_sum_per_rank = out_split - elif per_rank_split_lengths is not None: - # all the embs have the same dimension - dims_sum_per_rank = [] - for s in per_rank_split_lengths: - dims_sum_per_rank.append(s * D_local_list[0]) - else: - raise RuntimeError("Need to specify either out_split or per_rank_split_lengths") - - a2ai = All2AllVInfo( - dims_sum_per_rank=dims_sum_per_rank, - B_local=B_local, - B_local_list=B_local_list, - D_local_list=D_local_list, - B_global=B_global, - codecs=codecs, - ) - - if get_use_sync_collectives(): - return NoWait(all2allv_sync(group, a2ai, inputs)) - - myreq = Request(group, device=inputs[0].device) - All2Allv_Req.apply(group, myreq, a2ai, inputs) - - return myreq - - -def all2allv_sync( - pg: dist.ProcessGroup, - a2ai: All2AllVInfo, - inputs: List[Tensor], -) -> List[Tensor]: - input_split_sizes = [] - sum_D_local_list = sum(a2ai.D_local_list) - for m in a2ai.B_local_list: - input_split_sizes.append(m * sum_D_local_list) - - output_split_sizes = [] - for e in a2ai.dims_sum_per_rank: - output_split_sizes.append(a2ai.B_local * e) - - input = torch.cat(inputs, dim=1).view([-1]) - if a2ai.codecs is not None: - input = a2ai.codecs.forward.encode(input) - - with record_function("## alltoallv_bwd_single ##"): - output = torch.ops.torchrec.all_to_all_single( - input, - output_split_sizes, - input_split_sizes, - pg_name(pg), - pg.size(), - get_gradient_division(), - ) - - if a2ai.codecs is not None: - output = a2ai.codecs.forward.decode(output) - - outputs = [] - for out in output.split(output_split_sizes): - outputs.append(out.view([a2ai.B_local, -1])) - return outputs - - def reduce_scatter_pooled( inputs: List[Tensor], group: Optional[dist.ProcessGroup] = None, diff --git a/torchrec/distributed/embedding_dim_bucketer.py b/torchrec/distributed/embedding_dim_bucketer.py index 21283a445..ef2f58b15 100644 --- a/torchrec/distributed/embedding_dim_bucketer.py +++ b/torchrec/distributed/embedding_dim_bucketer.py @@ -10,10 +10,7 @@ from enum import Enum, unique from typing import Dict, List -from torchrec.distributed.embedding_types import ( - EmbeddingComputeKernel, - ShardedEmbeddingTable, -) +from torchrec.distributed.embedding_types import ShardedEmbeddingTable from torchrec.modules.embedding_configs import DATA_TYPE_NUM_BITS, DataType diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py index 050fde40e..1f1645335 100644 --- a/torchrec/distributed/embedding_lookup.py +++ b/torchrec/distributed/embedding_lookup.py @@ -1101,36 +1101,3 @@ def get_tbes_to_register( self, ) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]: return get_tbes_to_register_from_iterable(self._embedding_lookups_per_rank) - - -class InferCPUGroupedEmbeddingsLookup( - InferGroupedLookupMixin, - BaseEmbeddingLookup[InputDistOutputs, List[torch.Tensor]], - TBEToRegisterMixIn, -): - def __init__( - self, - grouped_configs_per_rank: List[List[GroupedEmbeddingConfig]], - world_size: int, - fused_params: Optional[Dict[str, Any]] = None, - device: Optional[torch.device] = None, - ) -> None: - super().__init__() - self._embedding_lookups_per_rank: 
List[MetaInferGroupedEmbeddingsLookup] = [] - - device_type: str = "cpu" if device is None else device.type - for rank in range(world_size): - self._embedding_lookups_per_rank.append( - MetaInferGroupedEmbeddingsLookup( - grouped_configs=grouped_configs_per_rank[rank], - # syntax for torchscript - # pyre-fixme[20]: Argument `index` expected. - device=torch.device(type=device_type), - fused_params=fused_params, - ) - ) - - def get_tbes_to_register( - self, - ) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]: - return get_tbes_to_register_from_iterable(self._embedding_lookups_per_rank) diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 84e033a31..a7ac5c972 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -610,9 +610,7 @@ def __init__( ) self._env = env # output parameters as DTensor in state dict - self._output_dtensor: bool = ( - fused_params.get("output_dtensor", False) if fused_params else False - ) + self._output_dtensor: bool = env.output_dtensor sharding_type_to_sharding_infos = create_sharding_infos_by_sharding( module, diff --git a/torchrec/distributed/fused_params.py b/torchrec/distributed/fused_params.py index 26af33938..71b6b4786 100644 --- a/torchrec/distributed/fused_params.py +++ b/torchrec/distributed/fused_params.py @@ -7,7 +7,7 @@ # pyre-strict -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, Optional import torch @@ -55,7 +55,7 @@ def is_fused_param_register_tbe(fused_params: Optional[Dict[str, Any]]) -> bool: def get_fused_param_tbe_row_alignment( - fused_params: Optional[Dict[str, Any]] + fused_params: Optional[Dict[str, Any]], ) -> Optional[int]: if fused_params is None or FUSED_PARAM_TBE_ROW_ALIGNMENT not in fused_params: return None @@ -64,7 +64,7 @@ def get_fused_param_tbe_row_alignment( def fused_param_bounds_check_mode( - fused_params: Optional[Dict[str, Any]] + fused_params: Optional[Dict[str, Any]], ) -> Optional[BoundsCheckMode]: if fused_params is None or FUSED_PARAM_BOUNDS_CHECK_MODE not in fused_params: return None @@ -73,7 +73,7 @@ def fused_param_bounds_check_mode( def fused_param_lengths_to_offsets_lookup( - fused_params: Optional[Dict[str, Any]] + fused_params: Optional[Dict[str, Any]], ) -> bool: if ( fused_params is None @@ -85,7 +85,7 @@ def fused_param_lengths_to_offsets_lookup( def is_fused_param_quant_state_dict_split_scale_bias( - fused_params: Optional[Dict[str, Any]] + fused_params: Optional[Dict[str, Any]], ) -> bool: return ( fused_params @@ -95,7 +95,7 @@ def is_fused_param_quant_state_dict_split_scale_bias( def tbe_fused_params( - fused_params: Optional[Dict[str, Any]] + fused_params: Optional[Dict[str, Any]], ) -> Optional[Dict[str, Any]]: if not fused_params: return None diff --git a/torchrec/distributed/shard.py b/torchrec/distributed/shard.py index 4c44ae221..a755d2c8b 100644 --- a/torchrec/distributed/shard.py +++ b/torchrec/distributed/shard.py @@ -29,7 +29,6 @@ ) from torchrec.distributed.utils import init_parameters from torchrec.modules.utils import reset_module_states_post_sharding -from torchrec.types import CacheMixin def _join_module_path(path: str, name: str) -> str: diff --git a/torchrec/distributed/sharding/cw_sharding.py b/torchrec/distributed/sharding/cw_sharding.py index 940f1a0ca..aa4fafa2b 100644 --- a/torchrec/distributed/sharding/cw_sharding.py +++ b/torchrec/distributed/sharding/cw_sharding.py @@ -32,7 +32,6 @@ DTensorMetadata, EmbeddingComputeKernel, 
InputDistOutputs, - KJTList, ShardedEmbeddingTable, ) from torchrec.distributed.sharding.tw_sharding import ( @@ -170,7 +169,7 @@ def _shard( ) dtensor_metadata = None - if info.fused_params.get("output_dtensor", False): # pyre-ignore[16] + if self._env.output_dtensor: dtensor_metadata = DTensorMetadata( mesh=self._env.device_mesh, placements=( @@ -187,8 +186,6 @@ def _shard( ), stride=info.param.stride(), ) - # to not pass onto TBE - info.fused_params.pop("output_dtensor", None) # pyre-ignore[16] # pyre-fixme [6] for i, rank in enumerate(info.param_sharding.ranks): diff --git a/torchrec/distributed/sharding/grid_sharding.py b/torchrec/distributed/sharding/grid_sharding.py index a0da146ea..c5cc31e87 100644 --- a/torchrec/distributed/sharding/grid_sharding.py +++ b/torchrec/distributed/sharding/grid_sharding.py @@ -232,7 +232,7 @@ def _shard( ) dtensor_metadata = None - if info.fused_params.get("output_dtensor", False): # pyre-ignore[16] + if self._env.output_dtensor: placements = ( (Replicate(), Shard(1)) if self._is_2D_parallel else (Shard(1),) ) @@ -246,9 +246,6 @@ def _shard( stride=info.param.stride(), ) - # to not pass onto TBE - info.fused_params.pop("output_dtensor", None) # pyre-ignore[16] - # Expectation is planner CW shards across a node, so each CW shard will have local_size number of row shards # pyre-fixme [6] for i, rank in enumerate(info.param_sharding.ranks): diff --git a/torchrec/distributed/sharding/rw_sequence_sharding.py b/torchrec/distributed/sharding/rw_sequence_sharding.py index 38b68c3ed..1d9fb71d5 100644 --- a/torchrec/distributed/sharding/rw_sequence_sharding.py +++ b/torchrec/distributed/sharding/rw_sequence_sharding.py @@ -17,7 +17,6 @@ ) from torchrec.distributed.embedding_lookup import ( GroupedEmbeddingsLookup, - InferCPUGroupedEmbeddingsLookup, InferGroupedEmbeddingsLookup, ) from torchrec.distributed.embedding_sharding import ( diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py index ff5d764ea..7111aa311 100644 --- a/torchrec/distributed/sharding/rw_sharding.py +++ b/torchrec/distributed/sharding/rw_sharding.py @@ -179,7 +179,7 @@ def _shard( ) dtensor_metadata = None - if info.fused_params.get("output_dtensor", False): # pyre-ignore[16] + if self._env.output_dtensor: placements = ( (Replicate(), Shard(0)) if self._is_2D_parallel else (Shard(0),) ) @@ -197,8 +197,6 @@ def _shard( ), stride=info.param.stride(), ) - # to not pass onto TBE - info.fused_params.pop("output_dtensor", None) # pyre-ignore[16] for rank in range(self._world_size): tables_per_rank[rank].append( diff --git a/torchrec/distributed/sharding/twrw_sharding.py b/torchrec/distributed/sharding/twrw_sharding.py index ae8a0c782..7541ca26a 100644 --- a/torchrec/distributed/sharding/twrw_sharding.py +++ b/torchrec/distributed/sharding/twrw_sharding.py @@ -164,7 +164,7 @@ def _shard( ) dtensor_metadata = None - if info.fused_params.get("output_dtensor", False): # pyre-ignore[16] + if self._env.output_dtensor: placements = (Shard(0),) dtensor_metadata = DTensorMetadata( mesh=self._env.device_mesh, @@ -175,8 +175,6 @@ def _shard( ), stride=info.param.stride(), ) - # to not pass onto TBE - info.fused_params.pop("output_dtensor", None) # pyre-ignore[16] for rank in range( table_node * local_size, diff --git a/torchrec/distributed/tensor_sharding.py b/torchrec/distributed/tensor_sharding.py index 9c211c5aa..6a7ca0715 100644 --- a/torchrec/distributed/tensor_sharding.py +++ b/torchrec/distributed/tensor_sharding.py @@ -11,7 +11,7 @@ from abc import 
ABC, abstractmethod from dataclasses import dataclass -from typing import cast, Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch from torch import distributed as dist diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py index dbd8f1007..4b0aedfd6 100644 --- a/torchrec/distributed/test_utils/test_sharding.py +++ b/torchrec/distributed/test_utils/test_sharding.py @@ -41,7 +41,6 @@ ) from torchrec.distributed.types import ( EmbeddingModuleShardingPlan, - EnumerableShardingSpec, ModuleSharder, ShardedTensor, ShardingEnv, @@ -288,7 +287,6 @@ def sharding_single_rank_test( world_size_2D: Optional[int] = None, node_group_size: Optional[int] = None, ) -> None: - with MultiProcessContext(rank, world_size, backend, local_size) as ctx: # Generate model & inputs. (global_model, inputs) = gen_model_and_input( diff --git a/torchrec/distributed/tests/test_2d_sharding.py b/torchrec/distributed/tests/test_2d_sharding.py index c76e8d4cf..4215d5dbc 100644 --- a/torchrec/distributed/tests/test_2d_sharding.py +++ b/torchrec/distributed/tests/test_2d_sharding.py @@ -466,8 +466,8 @@ def test_sharding_twrw_2D( self._test_sharding( world_size=self.WORLD_SIZE, - local_size=self.WORLD_SIZE_2D // 2, world_size_2D=self.WORLD_SIZE_2D, + node_group_size=self.WORLD_SIZE // 4, sharders=[ cast( ModuleSharder[nn.Module], diff --git a/torchrec/distributed/tests/test_comm.py b/torchrec/distributed/tests/test_comm.py index 02dbf02f3..d110e9740 100644 --- a/torchrec/distributed/tests/test_comm.py +++ b/torchrec/distributed/tests/test_comm.py @@ -204,84 +204,6 @@ def _run_multi_process_test( p.join() self.assertEqual(0, p.exitcode) - @classmethod - def _test_alltoallv( - cls, - rank: int, - world_size: int, - backend: str, - compile_config: _CompileConfig, - specify_pg: bool, - ) -> None: - dist.init_process_group(rank=rank, world_size=world_size, backend=backend) - pg = GroupMember.WORLD - assert pg is not None - - device = torch.device(f"cuda:{rank}") - - torch.cuda.set_device(device) - - B_global = 10 - D0 = 8 - D1 = 9 - - input_embedding0 = torch.rand( - (B_global, D0), - device=device, - requires_grad=True, - ) - input_embedding1 = torch.rand( - (B_global, D1), - device=device, - requires_grad=True, - ) - - input_embeddings = [input_embedding0, input_embedding1] - out_split = [17, 17] - - # pyre-ignore - def fn(*args, **kwargs) -> List[torch.Tensor]: - return comm_ops.alltoallv(*args, **kwargs).wait() - - fn_transform = compile_config_to_fn_transform(compile_config) - - with unittest.mock.patch( - "torch._dynamo.config.skip_torchrec", - False, - ): - v_embs_out = fn_transform(fn)( - input_embeddings, out_split=out_split, group=pg if specify_pg else None - ) - - res = torch.cat(v_embs_out, dim=1).cpu() - assert tuple(res.size()) == (5, 34) - dist.destroy_process_group() - - @unittest.skipIf( - torch.cuda.device_count() < 2, "Need at least two ranks to run this test" - ) - # pyre-ignore - @given( - specify_pg=st.sampled_from([True]), - test_compiled_with_noncompiled_ranks=st.sampled_from([False, True]), - ) - @settings(deadline=None) - def test_alltoallv( - self, - specify_pg: bool, - test_compiled_with_noncompiled_ranks: bool, - ) -> None: - self._run_multi_process_test( - world_size=self.WORLD_SIZE, - backend="nccl", - # pyre-ignore [6] - callable=self._test_alltoallv, - compile_config=_CompileConfig( - test_compiled_with_noncompiled_ranks=test_compiled_with_noncompiled_ranks - ), - specify_pg=specify_pg, - ) - 
@classmethod def _test_alltoall_sequence( cls, diff --git a/torchrec/distributed/tests/test_infer_shardings.py b/torchrec/distributed/tests/test_infer_shardings.py index 9bf975d3e..83b4649ee 100755 --- a/torchrec/distributed/tests/test_infer_shardings.py +++ b/torchrec/distributed/tests/test_infer_shardings.py @@ -37,10 +37,8 @@ ) from torchrec.distributed.quant_embedding import QuantEmbeddingCollectionSharder from torchrec.distributed.quant_embeddingbag import ( - QuantEmbeddingBagCollection, QuantEmbeddingBagCollectionSharder, QuantFeatureProcessedEmbeddingBagCollectionSharder, - ShardedQuantEmbeddingBagCollection, ) from torchrec.distributed.quant_state import sharded_tbes_weights_spec, WeightSpec from torchrec.distributed.shard import _shard_modules diff --git a/torchrec/distributed/tests/test_quant_model_parallel.py b/torchrec/distributed/tests/test_quant_model_parallel.py index 7dc4746de..131b9d3d6 100644 --- a/torchrec/distributed/tests/test_quant_model_parallel.py +++ b/torchrec/distributed/tests/test_quant_model_parallel.py @@ -8,7 +8,7 @@ # pyre-strict import unittest -from typing import Any, cast, Dict, List, Optional, Tuple +from typing import cast, Dict, Optional, Tuple import hypothesis.strategies as st import torch diff --git a/torchrec/distributed/tests/test_sequence_model_parallel_single_rank.py b/torchrec/distributed/tests/test_sequence_model_parallel_single_rank.py index 8e3699825..26ca8c55b 100644 --- a/torchrec/distributed/tests/test_sequence_model_parallel_single_rank.py +++ b/torchrec/distributed/tests/test_sequence_model_parallel_single_rank.py @@ -9,18 +9,13 @@ import unittest -from typing import cast, Dict, List, Optional, OrderedDict, Tuple +from typing import cast, OrderedDict import hypothesis.strategies as st import torch from hypothesis import given, settings, Verbosity -from torch import distributed as dist, nn -from torchrec import distributed as trec_dist -from torchrec.distributed import DistributedModelParallel +from torch import nn from torchrec.distributed.embedding_types import EmbeddingComputeKernel -from torchrec.distributed.model_parallel import get_default_sharders -from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology -from torchrec.distributed.test_utils.test_model import ModelInput from torchrec.distributed.test_utils.test_model_parallel_base import ( ModelParallelSingleRankBase, ) @@ -28,7 +23,7 @@ TestEmbeddingCollectionSharder, TestSequenceSparseNN, ) -from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType +from torchrec.distributed.types import ModuleSharder, ShardingType from torchrec.modules.embedding_configs import DataType, EmbeddingConfig diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py index e9189bd3f..d42a2e9ac 100644 --- a/torchrec/distributed/train_pipeline/train_pipelines.py +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -11,7 +11,6 @@ import contextlib import logging from collections import deque -from contextlib import contextmanager from dataclasses import dataclass from typing import ( Any, @@ -31,7 +30,6 @@ ) import torch -import torchrec.distributed.comm_ops from torch.autograd.profiler import record_function from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable from torchrec.distributed.model_parallel import ShardedModule @@ -760,7 +758,6 @@ def _grad_swap(self) -> None: param.grad = grad def _init_embedding_streams(self) -> None: - for _ in self._pipelined_modules: 
self._embedding_streams.append( (torch.get_device_module(self._device).Stream(priority=0)) diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index 44461752a..8085c415f 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -633,6 +633,7 @@ class KeyValueParams: gather_ssd_cache_stats: bool: whether enable ssd stats collection, std reporter and ods reporter report_interval: int: report interval in train iteration if gather_ssd_cache_stats is enabled ods_prefix: str: ods prefix for ods reporting + bulk_init_chunk_size: int: number of rows to insert into rocksdb in each chunk # Parameter Server (PS) Attributes ps_hosts (Optional[Tuple[Tuple[str, int]]]): List of PS host ip addresses @@ -652,6 +653,7 @@ class KeyValueParams: l2_cache_size: Optional[int] = None # size in GB max_l1_cache_size: Optional[int] = None # size in MB enable_async_update: Optional[bool] = None + bulk_init_chunk_size: Optional[int] = None # number of rows # Parameter Server (PS) Attributes ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None @@ -676,6 +678,7 @@ def __hash__(self) -> int: self.l2_cache_size, self.max_l1_cache_size, self.enable_async_update, + self.bulk_init_chunk_size, ) ) @@ -813,6 +816,7 @@ def __init__( world_size: int, rank: int, pg: Optional[dist.ProcessGroup] = None, + output_dtensor: bool = False, ) -> None: self.world_size = world_size self.rank = rank @@ -825,6 +829,7 @@ def __init__( if pg else None ) + self.output_dtensor: bool = output_dtensor @classmethod def from_process_group(cls, pg: dist.ProcessGroup) -> "ShardingEnv": @@ -886,6 +891,7 @@ def __init__( self.sharding_pg: dist.ProcessGroup = sharding_pg self.device_mesh: DeviceMesh = device_mesh self.node_group_size: Optional[int] = node_group_size + self.output_dtensor: bool = True def num_sharding_groups(self) -> int: """ diff --git a/torchrec/inference/client.py b/torchrec/inference/client.py index 725338f46..50bdc09ea 100644 --- a/torchrec/inference/client.py +++ b/torchrec/inference/client.py @@ -11,7 +11,6 @@ import grpc import predictor_pb2, predictor_pb2_grpc import torch -from torch.utils.data import DataLoader from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES from torchrec.datasets.random import RandomRecDataset from torchrec.datasets.utils import Batch diff --git a/torchrec/inference/tests/test_inference.py b/torchrec/inference/tests/test_inference.py index 7ee04d9f0..5c4563185 100644 --- a/torchrec/inference/tests/test_inference.py +++ b/torchrec/inference/tests/test_inference.py @@ -14,7 +14,6 @@ import torch from fbgemm_gpu.split_embedding_configs import SparseType -from torch.fx import symbolic_trace from torchrec import PoolingType from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES from torchrec.distributed.fused_params import ( diff --git a/torchrec/ir/types.py b/torchrec/ir/types.py index caa7a35a5..7dc1695b9 100644 --- a/torchrec/ir/types.py +++ b/torchrec/ir/types.py @@ -10,7 +10,7 @@ #!/usr/bin/env python3 import abc -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional import torch
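
Usage sketch for two API changes in this patch: `ShardingEnv.__init__` gains an `output_dtensor: bool = False` argument (sharding modules now read `env.output_dtensor` instead of a `fused_params["output_dtensor"]` entry, and `ShardingEnv2D` sets it to True), and `KeyValueParams` gains an optional `bulk_init_chunk_size` field. The example below is a minimal, non-authoritative sketch: it assumes a process group has already been initialized, and the chunk-size value is illustrative only.

    import torch.distributed as dist

    from torchrec.distributed.types import KeyValueParams, ShardingEnv

    # Assumes dist.init_process_group(...) has already been called.
    pg = dist.distributed_c10d._get_default_group()

    # DTensor-formatted state-dict output is now requested through the sharding
    # environment rather than through fused_params={"output_dtensor": True}.
    env = ShardingEnv(world_size=pg.size(), rank=pg.rank(), pg=pg, output_dtensor=True)

    # New optional knob for RocksDB-backed tables: rows inserted per chunk.
    # 1024 is an illustrative value, not a recommended default.
    kv_params = KeyValueParams(bulk_init_chunk_size=1024)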