diff --git a/fbgemm_gpu/cmake/tbe_sources.py b/fbgemm_gpu/cmake/tbe_sources.py
index 2c0504c9a..a32be1eff 100644
--- a/fbgemm_gpu/cmake/tbe_sources.py
+++ b/fbgemm_gpu/cmake/tbe_sources.py
@@ -191,6 +191,9 @@
     + [
         "gen_embedding_backward_split_common_device_kernel.cuh",
     ]
+    + [
+        "pt2_arg_utils.h",
+    ]
 )
 
 gen_defused_optim_templates = [
@@ -502,7 +505,6 @@
     for optimizer in COMMON_OPTIMIZERS + CPU_ONLY_OPTIMIZERS + GPU_ONLY_OPTIMIZERS
     for fstring in [
         "lookup_{}.py",
-        "lookup_{}_pt2.py",
     ]
 ]
 + [
@@ -510,7 +512,6 @@
     for optimizer in SSD_OPTIMIZERS
     for fstring in [
         "lookup_{}_ssd.py",
-        "lookup_{}_ssd_pt2.py",
     ]
 ]
 + [
diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py
index c97714857..4dae106bf 100644
--- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py
+++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py
@@ -197,7 +197,6 @@ def generate_backward_split_gpu(**kwargs: Any) -> None:
                 )
                 for filename in [
                     f"lookup_{optimizer}{sdesc}.py",
-                    f"lookup_{optimizer}{sdesc}_pt2.py",
                 ]:
                     template.write(
                         filename, is_fbcode=args.is_fbcode, ssd=ssd, **kwargs
@@ -331,6 +330,23 @@ def generate_rocm_backward_split(**kwargs: Any) -> None:
             },
         )
 
+    @staticmethod
+    def generate_backward_header(
+        aux_args: Dict[str, List[str]], aux_names: List[str]
+    ) -> None:
+        """
+        Generate a header file that contains the enum of argument order from the dict
+
+        Parameters:
+            aux_args (Dict[str, List[str]]): a dict containing a list of arguments
+            aux_names (List[str]): names of the argument types (e.g. aux_tensor, aux_int, etc.)
+        Return:
+            None
+        """
+        # Generate backward header for PT2 Autograd
+        template = CodeTemplate.load("training/pt2/pt2_arg_utils_template.h")
+        template.write(f"pt2_arg_utils.h", aux_args=aux_args, aux_names=aux_names)
+
     @staticmethod
     def generate_python_sources(
         all_optimizers: List[str], ssd_optimizers: List[str]
@@ -375,6 +391,40 @@ def generate() -> None:
         "actions_count",
     ]
 
+    aux_names = ["aux_tensor", "aux_int", "aux_float", "aux_bool"]
+    # This is a dict of auxiliary arguments used in the TBE PT2 interface where the aux
+    # arguments of a type are packed into a list for that type. This dict maintains the
+    # order of the arguments of each type.
+    aux_args: Dict[str, List[str]] = {
+        "aux_tensor": [
+            "B_offsets",  # 0
+            "vbe_output_offsets_feature_rank",  # 1
+            "vbe_B_offsets_rank_per_feature",  # 2
+            "lxu_cache_locations",  # 3
+            "uvm_cache_stats",  # 4
+            "prev_iter_dev",  # 5
+        ],
+        "aux_int": [
+            "iter",  # 0
+        ],
+        "aux_float": [
+            "gwd_lower_bound",  # 0
+            "max_gradient",  # 1
+        ],
+        "aux_bool": [
+            "is_experimental_tbe",  # 0
+            "use_uniq_cache_locations_bwd",  # 1
+            "use_homogeneous_placements",  # 2
+            "apply_global_weight_decay",  # 3
+            "gradient_clipping",  # 4
+            "stochastic_rounding",  # 5
+            "mixed_D",  # 6
+        ],
+    }
+    assert (
+        list(aux_args.keys()) == aux_names
+    ), f"{aux_names} must match {aux_args.keys()}"
+
     all_optimizers = []
     ssd_optimizers = []
 
@@ -399,6 +449,9 @@ def generate() -> None:
     BackwardSplitGenerator.generate_backward_grad()
     BackwardSplitGenerator.generate_backward_indices()
 
+    # Generate headers for backwards
+    BackwardSplitGenerator.generate_backward_header(aux_args, aux_names)
+
     BackwardSplitGenerator.generate_python_sources(all_optimizers, ssd_optimizers)
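Note: the `pt2_arg_utils.h` generated from this dict is what the PT2 C++ templates below include in order to index into the packed aux lists. A minimal sketch of what the emitted header could look like, assuming one plain enum per aux type — the real layout is whatever `training/pt2/pt2_arg_utils_template.h` renders; only the `IDX_*` names (which appear verbatim in the autograd template further down) and their order are taken from `aux_args`:

```cpp
// Hypothetical rendering of the generated pt2_arg_utils.h -- illustrative only.
// Values mirror the index comments in aux_args above.
#pragma once

enum AuxTensorIndex {
  IDX_B_OFFSETS = 0,
  IDX_VBE_OUTPUT_OFFSETS_FEATURE_RANK = 1,
  IDX_VBE_B_OFFSETS_RANK_PER_FEATURE = 2,
  IDX_LXU_CACHE_LOCATIONS = 3,
  IDX_UVM_CACHE_STATS = 4,
  IDX_PREV_ITER_DEV = 5,
};

enum AuxIntIndex {
  IDX_ITER = 0,
};

enum AuxFloatIndex {
  IDX_GWD_LOWER_BOUND = 0,
  IDX_MAX_GRADIENT = 1,
};

enum AuxBoolIndex {
  IDX_IS_EXPERIMENTAL_TBE = 0,
  IDX_USE_UNIQ_CACHE_LOCATIONS_BWD = 1,
  IDX_USE_HOMOGENEOUS_PLACEMENTS = 2,
  IDX_APPLY_GLOBAL_WEIGHT_DECAY = 3,
  IDX_GRADIENT_CLIPPING = 4,
  IDX_STOCHASTIC_ROUNDING = 5,
  IDX_MIXED_D = 6,
};
```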
diff --git a/fbgemm_gpu/codegen/genscript/optimizer_args.py b/fbgemm_gpu/codegen/genscript/optimizer_args.py
index 669b1a44f..f432895cf 100644
--- a/fbgemm_gpu/codegen/genscript/optimizer_args.py
+++ b/fbgemm_gpu/codegen/genscript/optimizer_args.py
@@ -7,6 +7,7 @@
 # pyre-strict
 # pyre-ignore-all-errors[29]
+# pyre-ignore-all-errors[53]
 # flake8: noqa F401
 
@@ -205,6 +206,40 @@ def schema_bool_arg(name: str, default: bool = False) -> str:
     return f"bool {name} = {default}"
 
 
+def list_arg(ty: str) -> str:
+    """
+    Returns a C++ argument for a list of optimizer arguments of the given type.
+
+    Parameters:
+        ty (str) - type of the list e.g., "int", "float", "tensor"
+    Returns:
+        C++ argument for a list of the given type e.g., for a list of int returns "std::vector optim_int",
+    """
+    return {
+        "tensor": "std::vector> optim_tensor",
+        "int": "std::vector optim_int",
+        "float": "std::vector optim_float",
+        "bool": "c10::List optim_bool",
+    }[ty]
+
+
+def schema_list_arg(ty: str) -> str:
+    """
+    Returns a C++ schema for a list of optimizer arguments of the given type.
+
+    Parameters:
+        ty (str) - type of the list e.g., "int", "float", "tensor"
+    Returns:
+        C++ argument for a list of the given type e.g., for a list of int returns "int[] optim_int",
+    """
+    return {
+        "tensor": "Tensor?[] optim_tensor",
+        "int": "int[] optim_int",
+        "float": "float[] optim_float",
+        "bool": "bool[] optim_bool",
+    }[ty]
+
+
 def optional_tensor_arg(name: str) -> str:
     return f"std::optional {name} = std::nullopt"
 
@@ -230,7 +265,6 @@ def schema_optional_tensorlist_arg(name: str) -> str:
 
 
 def make_kernel_arg(
-    # pyre-fixme[11]: Annotation `ArgType` is not defined as a type.
     ty: ArgType,
     name: str,
     default: Union[int, float, None],
@@ -505,6 +539,10 @@ class PT2ArgsSet:
     split_function_schemas: List[str]
     split_saved_tensorlist: List[str]
     split_saved_tensorlist_optional: List[str]
+    split_saved_data: List[dict[str, str]]
+    split_variables: List[str]
+    split_unpacked_arg_names: List[str]
+    split_args_dict: Dict[str, List[str]]
 
     @staticmethod
     # pyre-ignore[3]
@@ -525,27 +563,52 @@ def create(
         Returns:
             PT2ArgsSet object with the following attributes:
                 split_function_args: List[str] - List of function arguments used in unified lookup and autograd functions
-                    Tensors will be packed and pass as TensorList
-                    e.g., ['at::TensorList momentum1', 'double eps', 'double weight_decay'].
+                    Tensors will be packed and passed as TensorList. Auxiliary arguments will be packed in a dict.
+                    e.g., ['at::TensorList momentum1', 'at::Dict optim_int'].
                 split_function_arg_names: List[str] - List of argument names used in unified lookup and autograd functions
-                    e.g., ['momentum1', 'eps', 'weight_decay'].
+                    e.g., ['momentum1', 'optim_int', 'optim_float'].
                 split_function_schemas: List[str] - List of arguments used in unified lookup and autograd functions in the schema format
                     e.g., ['Tensor[] momentum1', 'float eps', 'float weight_decay'].
                 split_saved_tensorlist: List[str] - List of tensor names that are packed into tensorlist and will be unpacked in
                     PT2 autograd function. e.g., ['momentum1'].
                 split_saved_tensorlist_optional: List[str] - List of tensor names that are packed into tensorlist but are optional
                     and will be unpacked in PT2 autograd function e.g., ['row_counter'].
+                split_saved_data: List[dict[str, str]] - List of non-tensor arguments that are saved for backward
+                split_unpacked_arg_names: List[str] - List of argument names, unrolled from list
+                    e.g., ['momentum1', 'eps', 'weight_decay', 'iter'].
+                split_args_dict: Dict[str, List[str]] - Dict of optim arguments' types containing the argument names of that type.
+                    e.g., if an optimizer only has an int argument called iter, the dict will look like:
+                    {'optim_tensor': [], 'optim_int': ['iter'], 'optim_float': [], 'optim_bool': []}
         """
         split_function_arg_names = []
         split_function_args = []
         split_function_schemas = []
         split_saved_tensorlist = []
         split_saved_tensorlist_optional = []
+        split_saved_data = []
+        split_variables = []
+        split_unpacked_arg_names = []
+        has_optim_tensor = False  # optim tensors here are optional tensor
+        has_optim_int = False
+        has_optim_float = False
+        has_optim_bool = False
+        split_args_dict = {
+            "optim_tensor": [],
+            "optim_int": [],
+            "optim_float": [],
+            "optim_bool": [],
+        }
+        # list of symint args to be appended after optim_xxx args
+        # since they have default values
+        symint_list: List[OptimItem] = []
+
         for s in arg_spec:
             if s.name == "learning_rate_tensor":
                 split_function_arg_names.append(s.name)
+                split_unpacked_arg_names.append(s.name)
                 split_function_args.append(tensor_arg(s.name))
                 split_function_schemas.append(tensor_arg(s.name))
+                split_variables.append(f"ret.push_back(Variable()); // {s.name}")
             elif s.ty in (
                 ArgType.TENSOR,
                 ArgType.INT_TENSOR,
@@ -553,31 +616,125 @@
                 ArgType.PLACEHOLDER_TENSOR,
             ):
                 name = s.name
-                split_function_arg_names.append(name)
+                split_unpacked_arg_names.append(name)
                 if s.is_optional:
-                    split_function_args.append(optional_tensorlist_arg(name))
-                    split_function_schemas.append(schema_optional_tensorlist_arg(name))
                     split_saved_tensorlist_optional.append(name)
+                    split_args_dict["optim_tensor"].append(s.name)
+                    has_optim_tensor = True
                 else:
                     split_function_args.append(
                         tensor_list_arg_no_default(name, pass_by_ref=False)
                     )
+                    split_function_arg_names.append(name)
                     split_function_schemas.append(
                         schema_tensor_list_arg_no_default(name)
                     )
                     split_saved_tensorlist.append(name)
+                    split_variables.append(
+                        f"ret.push_back(Variable()); // {s.name}_dev or host"
+                    )
+                    split_variables.append(
+                        f"ret.push_back(Variable()); // {s.name}_placements"
+                    )
+                    split_variables.append(
+                        f"ret.push_back(Variable()); // {s.name}_offsets"
+                    )
+                    split_variables.append("if (" + name + "_host.numel() == 0) {")
+                    split_variables.append(
+                        f"ret.push_back(Variable()); // {s.name}_uvm"
+                    )
+                    split_variables.append("}")
             else:
-                split_function_arg_names.append(s.name)
-                split_function_args.append(make_function_arg(s.ty, s.name, s.default))
-                split_function_schemas.append(
-                    make_function_schema_arg(s.ty, s.name, s.default)
-                )
+                if s.ty == ArgType.INT:
+                    # iter is passed in aux_int
+                    if s.name != "iter":
+                        split_args_dict["optim_int"].append(s.name)
+                        split_saved_data.append(
+                            (
+                                s.name,
+                                f'optim_int[{len(split_args_dict["optim_int"]) - 1}]',
+                                make_ivalue_cast(s.ty),
+                                "int64_t",
+                            )
+                        )
+                        has_optim_int = True
+                elif s.ty == ArgType.SYM_INT:
+                    symint_list.append(s)
+                    split_saved_data.append(
+                        (
+                            s.name,
+                            "",
+                            make_ivalue_cast(s.ty),
+                            "c10::SymInt",
+                        )
+                    )
+                elif s.ty == ArgType.FLOAT:
+                    split_args_dict["optim_float"].append(s.name)
+                    split_saved_data.append(
+                        (
+                            s.name,
+                            f'optim_float[{len(split_args_dict["optim_float"])- 1}]',
+                            make_ivalue_cast(s.ty),
+                            "double",
+                        )
+                    )
+                    has_optim_float = True
+                elif s.ty == ArgType.BOOL:
+                    split_args_dict["optim_bool"].append(s.name)
+                    split_saved_data.append(
+                        (
+                            s.name,
+                            f'optim_bool[{len(split_args_dict["optim_bool"])- 1}]',
+                            make_ivalue_cast(s.ty),
+                            "bool",
+                        )
+                    )
+                    has_optim_bool = True
+                else:
+                    raise ValueError(f"Unsupported type {s.ty}")
+                split_unpacked_arg_names.append(s.name)
+
+        def append_lists(type_name: str) -> None:
+            """
+            Append the list as one argument to the list of function arguments, schemas, names and saved_variables.
+            e.g., if type_name is "tensor", optim_tensor will be appended with the corresponding syntax.
+
+            Parameters:
+                type_name (str) - type name of the list to be appended
+
+            Returns:
+                None
+            """
+            split_function_args.append(list_arg(type_name))
+            split_function_schemas.append(schema_list_arg(type_name))
+            split_function_arg_names.append(f"optim_{type_name}")
+            split_variables.append(f"ret.push_back(Variable()); // optim_{type_name}")
+
+        if has_optim_tensor:
+            append_lists("tensor")
+        if has_optim_int:
+            append_lists("int")
+        if has_optim_float:
+            append_lists("float")
+        if has_optim_bool:
+            append_lists("bool")
+        for s in symint_list:
+            split_function_arg_names.append(s.name)
+            split_function_args.append(make_function_arg(s.ty, s.name, s.default))
+            split_function_schemas.append(
+                make_function_schema_arg(s.ty, s.name, s.default)
+            )
+            split_variables.append(f"ret.push_back(Variable()); // {s.name}")
         return PT2ArgsSet(
             split_function_args=split_function_args,
             split_function_arg_names=split_function_arg_names,
             split_function_schemas=split_function_schemas,
             split_saved_tensorlist=split_saved_tensorlist,
             split_saved_tensorlist_optional=split_saved_tensorlist_optional,
+            split_saved_data=split_saved_data,
+            split_variables=split_variables,
+            split_unpacked_arg_names=split_unpacked_arg_names,
+            split_args_dict=split_args_dict,
         )
@@ -637,12 +794,14 @@ def create(
         if s.is_optional:
             has_optional_tensors = True
 
-    # Optional tensors are converted to tensor in autograd functions
-    # Reorganize arguments for wrapper, backend and kernel functions
+    # Optim arg order: non-optional tensors, learning_rate_tensor, non-tensors, optional tensors
+    # The optional tensors are converted to Tensor in autograd functions
+    # Hence, need to reorganize such that the tensors come before non-tensors which have default values
+    # This is used in wrapper, backend and kernel functions
     if has_optional_tensors:
-        # Arg order: non-optional tensors, learning_rate_tensor, optional tensors as tensors, non-tensors,
+        # reordered args for split_arg_spec: non-optional tensors, learning_rate_tensor, optional tensors as tensors, non-tensors
         split_arg_spec = reorder_args(split_arg_spec)
-        # Arg order: non-optional tensors, optional tensors as tensors, learning rate (float), non-tensors
+        # reordered args for kernel_split_arg_spec: non-optional tensors, optional tensors as tensors, learning rate (float), non-tensors
        kernel_split_arg_spec = reorder_args(kernel_split_arg_spec)
 
     # Compute placeholder tensor combinations
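To make the packing rules concrete, consider a hypothetical optimizer whose arg_spec carries momentum1 (tensor), row_counter (optional tensor), eps (float), and iter (int). Following create() above, momentum1 keeps its own TensorList argument, row_counter is folded into optim_tensor, eps lands in optim_float, and iter is skipped because it travels in aux_int. A sketch of the resulting unified argument surface follows; the exact C++ element types come from list_arg(), so the spellings below are assumptions:

```cpp
#include <optional>
#include <vector>
#include <ATen/ATen.h>

// Illustrative only -- mirrors what split_function_args / split_args_dict would
// hold for the hypothetical (momentum1, row_counter?, eps, iter) optimizer.
void unified_lookup_args_sketch(
    at::TensorList momentum1,                             // own TensorList argument
    std::vector<std::optional<at::Tensor>> optim_tensor,  // ["row_counter"], packed flat
    std::vector<double> optim_float);                     // ["eps"]

// Registered schema fragment (from schema_list_arg):
//   "Tensor[] momentum1, Tensor?[] optim_tensor, float[] optim_float"
// split_args_dict == {"optim_tensor": ["row_counter"], "optim_int": [],
//                     "optim_float": ["eps"], "optim_bool": []}
```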
diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp
index a6ccbd7ed..6a2669c11 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp
@@ -171,10 +171,10 @@ enum SSDTensor {
         {%- endif %}
         BT_block_size,
         max_segment_length_per_warp,
-        {%- if optimizer != "none" and not dense %}
+        {%- if not dense %}
+        {%- if optimizer != "none" %}
         stochastic_rounding,
         {%- endif %}
-        {%- if not dense %}
         info_B_num_bits,
         info_B_mask_int64,
         {%- endif %}
@@ -311,42 +311,41 @@ enum SSDTensor {
         hash_size_cumsum,
         total_hash_size_bits,
         indices,
-        {%- if not nobag and dense and not vbe %}
+        {%- if dense and nobag %}
+        offsets
+        {%- else %}
         offsets,
+        {%- endif %}
+        {%- if not nobag %}
         pooling_mode,
         indice_weights,
+        {%- if dense and not vbe %}
         feature_requires_grad
-        {%- elif not nobag %}
-        offsets,
-        pooling_mode,
-        indice_weights,
-        feature_requires_grad,
-        {%- elif nobag and dense and not vbe %}
-        offsets
        {%- else %}
-        offsets,
+        feature_requires_grad,
        {%- endif %}
+        {%- endif %} {# /* if not nobag */ #}
        {%- if not dense %}
        lxu_cache_locations,
        uvm_cache_stats,
-        {%- endif %}
-        {%- if optimizer != "none" and not dense %}
+        {%- if optimizer != "none" %}
        gradient_clipping,
        max_gradient,
        stochastic_rounding,
        {%- endif %}
+        {%- endif %} {# /* if not dense */ #}
        {%- if vbe %}
        B_offsets,
        vbe_output_offsets_feature_rank,
        vbe_B_offsets_rank_per_feature,
        max_B,
        max_B_feature_rank,
-        {%- endif %}
-        {%- if vbe and not dense %}
-        vbe_output_size,
-        {%- elif vbe and dense %}
+        {%- if dense %}
        vbe_output_size
-        {%- endif %}
+        {%- else %}
+        vbe_output_size,
+        {%- endif %} {# /* if dense */ #}
+        {%- endif %} {# /* if vbe */ #}
        {%- if not dense %}
        is_experimental,
        use_uniq_cache_locations_bwd,
@@ -359,12 +358,12 @@ enum SSDTensor {
        iter,
        {%- endif %}
        gwd_lower_bound,
-        {%- endif %}
+        {%- endif %} {# /* if is_gwd */ #}
        {%- if ssd %}
        ssd_tensors.value(),
        {%- endif %}
        {{ args.split_function_arg_names_autograd | join(", ") }}
-        {%- endif %}
+        {%- endif %} {# /* if not dense */ #}
    )[0];
 {%- endmacro %}
@@ -577,20 +576,19 @@ class {{ autograd_func }} :
      const Tensor& hash_size_cumsum,
      const int64_t total_hash_size_bits,
      const Tensor& indices,
-      {%- if not nobag and dense and not vbe %}
+      {%- if dense and nobag %}
+      const Tensor& offsets
+      {%- else %}
      const Tensor& offsets,
+      {%- endif %}
+      {%- if not nobag %}
      const int64_t pooling_mode,
      const std::optional& indice_weights,
+      {%- if dense and not vbe %}
      const std::optional& feature_requires_grad
-      {%- elif not nobag %}
-      const Tensor& offsets,
-      const int64_t pooling_mode,
-      const std::optional& indice_weights,
-      const std::optional& feature_requires_grad,
-      {%- elif nobag and dense and not vbe %}
-      const Tensor& offsets
      {%- else %}
-      const Tensor& offsets,
+      const std::optional& feature_requires_grad,
+      {%- endif %}
      {%- endif %}
      {%- if not dense %}
      const Tensor& lxu_cache_locations,
@@ -619,7 +617,7 @@ class {{ autograd_func }} :
      const int64_t iter,
      {%- endif %}
      const double gwd_lower_bound,
-      {%- endif %}
+      {%- endif %} {#-/* if is_gwd */#}
      {%- if ssd %}
      const at::TensorList& ssd_tensors,
      {%- endif %}
@@ -633,14 +631,13 @@ class {{ autograd_func }} :
      const c10::SymInt max_B_feature_rank,
      const c10::SymInt vbe_output_size
      {%- endif %}
-      {%- endif %}) {
+      {%- endif %} {# /* if not dense */ #}) {
    const auto T = weights_offsets.sym_numel();
    {%- if vbe %}
    const auto B_offsets_ = B_offsets.value_or(Tensor());
    const auto vbe_output_offsets_feature_rank_ = vbe_output_offsets_feature_rank.value_or(Tensor());
    const auto vbe_B_offsets_rank_per_feature_ = vbe_B_offsets_rank_per_feature.value_or(Tensor());
-    const c10::SymInt max_B_ = max_B;
    {%- else %}
    const auto max_B_ = offsets.sym_size(0) / T;
diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp
index b224c3e70..9425c3cca 100644
--- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp
+++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp
@@ -38,12 +38,16 @@
 #include "fbgemm_gpu/embedding_common.h"
 #include "fbgemm_gpu/split_embeddings_utils.h"
 #include "fbgemm_gpu/config/feature_gates.h"
+#include "fbgemm_gpu/utils/tensor_utils.h"
+#include "torch/csrc/autograd/record_function_ops.h"
 #include "torch/csrc/autograd/record_function_ops.h"
 {%- if has_vbe_support %}
 #include "fbgemm_gpu/utils/pt2_autograd_utils.h"
 {%- endif %}
 
+#include "pt2_arg_utils.h"
+
 using Tensor = at::Tensor;
 
 using namespace fbgemm_gpu;
@@ -92,7 +96,7 @@ enum SSDTensor {
      const Tensor& /*weights_dev*/,
      {%- if not dense %}
      const Tensor& /*weights_uvm*/,
-      const Tensor& /*lxu_cache_weights*/,
+      const Tensor& /*weights_lxu_cache*/,
      const Tensor& /*weights_placements*/,
      {%- endif %}
      const Tensor& /*weights_offsets*/,
@@ -136,7 +140,7 @@ enum SSDTensor {
        weights_host,
        flatten_weights_dev,
        weights_uvm,
-        lxu_cache_weights,
+        weights_lxu_cache,
        weights_placements,
        weights_offsets,
        {%- if nobag %}
@@ -155,7 +159,7 @@ enum SSDTensor {
        {%- endif %} {# /* if not nobag */ #}
        {%- if not dense %}
        {{ "ssd_tensors[SSDTensor::ROW_ADDRS]" if ssd else "lxu_cache_locations" }},
-        uvm_cache_stats_,
+        uvm_cache_stats,
        {%- endif %}
        {%- if not nobag %}
        {%- if vbe %}
@@ -214,6 +218,7 @@ enum SSDTensor {
      {%- else %}
      const Tensor& /*D_offsets*/,
      const c10::SymInt /*max_D*/,
+      const bool /*mixed_D*/,
      {%- endif %}
      const Tensor& /*hash_size_cumsum*/,
      const int64_t /*total_hash_size_bits*/,
@@ -261,17 +266,22 @@ enum SSDTensor {
    grad_weights_dev = embedding_codegen{{ wdesc }}_backward_op.call(
        grad_output,
+        {% if dense %}
+        dev_weights,
+        {% else %}
        weights_host,
        weights_dev,
        weights_uvm,
-        lxu_cache_weights,
+        weights_lxu_cache,
        weights_placements,
        weights_offsets,
+        {% endif %}
        {% if nobag %}
        D,
        {%- else %}
        D_offsets,
        max_D,
+        mixed_D,
        {%- endif %} {# /* if nobag */ #}
        hash_size_cumsum,
        total_hash_size_bits,
@@ -283,16 +293,18 @@ enum SSDTensor {
        {%- endif %} {# /* if not nobag */ #}
        {%- if ssd %}
        ssd_row_addrs,
-        {%- else %}
+        {%- elif not dense %}
        lxu_cache_locations,
        {%- endif %}
        BT_block_size,
        max_segment_length_per_warp,
+        {%- if not dense %}
        {%- if optimizer != "none" %}
        stochastic_rounding,
        {%- endif %}
        info_B_num_bits,
        info_B_mask_int64,
+        {%- endif %} {# /* if not dense */ #}
        {%- if vbe %}
        B_offsets,
        vbe_row_output_offsets,
@@ -311,7 +323,11 @@ enum SSDTensor {
        {%- endif %}
        gwd_lower_bound,
        {%- endif %} {# /* if is_gwd */ #}
+        {%- if dense %}
+        /*unused=*/0
+        {%- else %}
        {{ args_pt2.split_function_arg_names | join(", ") }}
+        {%- endif %}
        {%- if not nobag %}
        , output_dtype
        {%- endif %}
@@ -321,73 +337,65 @@ enum SSDTensor {
      record_trace->record.end();
    }
 
-    return {
-      {%- if not dense %}
-      Tensor(), // placeholder autograd tensor
-      {%- endif %}
-      Variable(), // output_dtype
-      Variable(), // weights_host
-      grad_weights_dev, // weights_dev
-      {%- if not dense %}
-      Variable(), // weights_uvm
-      Variable(), // lxu_cache_weights
-      Variable(), // weights_placements
-      {%- endif %}
-      Variable(), // weights_offsets
-      {%- if nobag %}
-      Variable(), // D
-      {%- else %}
-      Variable(), // D_offsets
-      Variable(), // total_D
-      Variable(), // max_D
-      {%- endif %}
-      Variable(), // hash_size_cumsum
-      Variable(), //total_hash_size_bits
-      Variable(), // indices
-      Variable(), // offsets
-      {%- if not nobag %}
-      Variable(), // pooling_mode
-      grad_indice_weights, // indice_weights
-      Variable(), // feature_requires_grad
-      {%- endif %}
-      {%- if not dense %}
-      Variable(), // lxu_cache_locations
-      Variable(), // uvm_cache_stats
-      {%- endif %}
-      {%- if optimizer != "none" and not dense %}
-      Variable(), // gradient_clipping
-      Variable(), // max_gradient
-      Variable(), // stochastic_rounding
-      {%- endif %}
-      {%- if vbe %}
-      Variable(), // B_offsets
-      Variable(), // vbe_output_offsets_feature_rank
-      Variable(), // vbe_B_offsets_rank_per_feature
-      Variable(), // max_B
-      Variable(), // max_B_feature_rank
-      Variable(), // vbe_output_size
-      {%- endif %}
-      {%- if not dense %}
-      Variable(), // is_experimental
-      Variable(), // use_uniq_cache_locations_bwd
-      Variable(), // use_homogeneous_placements
-      {%- endif %}
-      {%- if is_gwd %}
-      {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %}
-      Variable(), // prev_iter_dev
-      {%- endif %}
-      {%- if "iter" not in args_pt2.split_function_arg_names %}
-      Variable(), // iter
-      {%- endif %}
-      Variable(), // gwd_lower_bound
-      {%- endif %}
-      {%- if ssd %}
-      {%- for tensor in ssd_tensors %}
-      Variable(), // {{ tensor }}
-      {%- endfor %}
-      {%- endif %}
-      {{ args_pt2.split_variables | join(", ") }}
-    };
+    // Number of returned gradients has to match the input to Autograd's forward
+    // The number of items in the tensorlist differs between devices and is determined at runtime
+    std::vector ret;
+
+    {%- if not dense %}
+    ret.push_back(Variable()); // placeholder autograd tensor
+    {%- endif %}
+    ret.push_back(Variable()); // output_dtype
+    {%- if not dense %}
+    if (weights_host.numel() > 0) {
+      ret.push_back(Tensor()); // host_weights
+    }
+    else {
+      ret.push_back(grad_weights_dev); // dev_weights
+      ret.push_back(Variable()); // weights_uvm
+      ret.push_back(Variable()); // weights_lxu_cache
+    }
+    ret.push_back(Variable()); // weights_placement
+    {%- endif %}
+    ret.push_back(Variable()); // weights_offsets
+    {%- if nobag %}
+    ret.push_back(Variable()); // D
+    {%- else %}
+    ret.push_back(Variable()); // D_offsets
+    ret.push_back(Variable()); // total_D
+    ret.push_back(Variable()); // max_D
+    {%- endif %}
+    ret.push_back(Variable()); // hash_size_cumsum
+    ret.push_back(Variable()); // total_hash_size_bits
+    ret.push_back(Variable()); // indices
+    ret.push_back(Variable()); // offsets
+    {%- if not nobag %}
+    ret.push_back(Variable()); // pooling_mode
+    ret.push_back(grad_indice_weights); // indice_weights
+    ret.push_back(Variable()); // feature_requires_grad
+    {%- endif %}
+    {%- if vbe %}
+    {%- if dense %}
+    ret.push_back(Variable()); // B_offsets
+    ret.push_back(Variable()); // vbe_output_offsets_feature_rank
+    ret.push_back(Variable()); // vbe_B_offsets_rank_per_feature
+    {%- endif %} {# /* if dense */ #}
+    ret.push_back(Variable()); // max_B
+    ret.push_back(Variable()); // max_B_feature_rank
+    ret.push_back(Variable()); // vbe_output_size
+    {%- endif %} {# /* if vbe */ #}
+    {%- if not dense %}
+    ret.push_back(Variable()); // aux_tensor
+    ret.push_back(Variable()); // aux_int
+    ret.push_back(Variable()); // aux_float
+    ret.push_back(Variable()); // aux_bool
+    {%- endif %}
+    {%- if ssd %}
+    {%- for tensor in ssd_tensors %}
+    ret.push_back(Variable()); // {{ tensor }}
+    {%- endfor %}
+    {%- endif %}
+    {{ args_pt2.unified_pt2.split_variables | join("\n") }}
+    return ret;
 {%- endmacro %}
 
 /* This macro generates a code blob that calls corresponding autograd function
@@ -407,9 +415,11 @@ enum SSDTensor {
      placeholder_autograd_tensor,
      {%- endif %}
      output_dtype,
+      {%- if dense %}
+      dev_weights,
+      weights_offsets,
+      {%- else %}
      weights,
-      {%- if not dense %}
-      lxu_cache_weights,
      {%- endif %}
      {%- if nobag %}
      max_D,
@@ -421,51 +431,35 @@ enum SSDTensor {
      hash_size_cumsum,
      total_hash_size_bits,
      indices,
-      {%- if not nobag and dense and not vbe %}
+      {%- if dense and nobag %}
+      offsets
+      {%- else %}
      offsets,
+      {%- endif %}
+      {%- if not nobag %}
      pooling_mode,
      indice_weights,
+      {%- if dense and not vbe %}
      feature_requires_grad
-      {%- elif not nobag %}
-      offsets,
-      pooling_mode,
-      indice_weights,
-      feature_requires_grad,
-      {%- elif nobag and dense and not vbe %}
-      offsets
      {%- else %}
-      offsets,
-      {%- endif %}
-      {%- if not dense %}
-      lxu_cache_locations,
-      uvm_cache_stats,
+      feature_requires_grad,
      {%- endif %}
-      {%- if optimizer != "none" and not dense %}
-      gradient_clipping,
-      max_gradient,
-      stochastic_rounding,
      {%- endif %}
      {%- if vbe %}
+      {%- if dense %}
      B_offsets,
      vbe_output_offsets_feature_rank,
      vbe_B_offsets_rank_per_feature,
+      {%- endif %} {# /* if dense */ #}
      max_B,
      max_B_feature_rank,
      vbe_output_size,
-      {%- endif %}
+      {%- endif %} {# /* if vbe */ #}
      {%- if not dense %}
-      is_experimental,
-      use_uniq_cache_locations_bwd,
-      use_homogeneous_placements,
-      {%- if is_gwd %}
-      {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %}
-      prev_iter_dev,
-      {%- endif %}
-      {%- if "iter" not in args_pt2.split_function_arg_names %}
-      iter,
-      {%- endif %}
-      gwd_lower_bound,
-      {%- endif %}
+      aux_tensor,
+      aux_int,
+      aux_float,
+      aux_bool,
      {%- if ssd %}
      ssd_tensors.value(),
      {%- endif %}
@@ -474,39 +468,69 @@ enum SSDTensor {
    )[0];
 {%- endmacro %}
 
-/* This macro generates a code blob for unpacking the tensor list
+/* This macro generates a code blob for unpacking TensorList
 */
 {%- macro unpack_tensorlist(name) %}
-    const Tensor {{ name }}_host = {{ name }}[0];
-    const Tensor {{ name }}_dev = {{ name }}[1];
-    const Tensor {{ name }}_uvm = {{ name }}[2];
-    const Tensor {{ name }}_placements = {{ name }}[3];
-    const Tensor {{ name }}_offsets = {{ name }}[4];
-{%- endmacro %}
-
-{%- macro unpack_tensorlist_optional(name) %}
    Tensor {{ name }}_host;
    Tensor {{ name }}_dev;
    Tensor {{ name }}_uvm;
    Tensor {{ name }}_placements;
    Tensor {{ name }}_offsets;
-    if ({{ name }}.has_value()) {
-      at::TensorList _{{ name }} = {{ name }}.value();
-      {{ name }}_host = _{{ name }}[0];
-      {{ name }}_dev = _{{ name }}[1];
-      {{ name }}_uvm = _{{ name }}[2];
-      {{ name }}_placements = _{{ name }}[3];
-      {{ name }}_offsets = _{{ name }}[4];
+    {%- if name == "weights" %}
+    Tensor {{ name }}_lxu_cache;
+    {%- endif %}
+
+    if ({{ name }}.size() == 3) {
+      TENSOR_ON_CPU_OR_MTIA({{ name }}[0]);
+      TENSORS_EMPTY_OR_ON_SAME_DEVICE({{ name }}[0], {{ name }}[1]);
+      TENSORS_EMPTY_OR_ON_SAME_DEVICE({{ name }}[0], {{ name }}[2]);
+      {{ name }}_host = {{ name }}[0];
+      {{ name }}_placements = {{ name }}[1];
+      {{ name }}_offsets = {{ name }}[2];
+    }
+    else if ({{ name }}.size() == {{ 5 if name == "weights" else 4 }}) {
+      TENSOR_ON_CUDA_GPU({{ name }}[0]);
+      TENSORS_EMPTY_OR_ON_SAME_DEVICE({{ name }}[0], {{ name }}[1]);
+      TENSORS_EMPTY_OR_ON_SAME_DEVICE({{ name }}[0], {{ name }}[2]);
+      TENSORS_EMPTY_OR_ON_SAME_DEVICE({{ name }}[0], {{ name }}[3]);
+      {%- if name == "weights" %}
+      TENSORS_EMPTY_OR_ON_SAME_DEVICE({{ name }}[0], {{ name }}[4]);
+      {%- endif %}
+      {{ name }}_dev = {{ name }}[0];
+      {{ name }}_uvm = {{ name }}[1];
+      {{ name }}_placements = {{ name }}[2];
+      {{ name }}_offsets = {{ name }}[3];
+      {%- if name == "weights" %}
+      {{ name }}_lxu_cache = {{ name }}[4];
+      {%- endif %}
    }
-    else{
-      {{ name }}_host = at::empty({0}, weights_host.options());
-      {{ name }}_dev = at::empty({0}, weights_dev.options());
-      {{ name }}_uvm = at::empty({0}, weights_uvm.options());
-      {{ name }}_placements = at::empty({0}, weights_placements.options());
-      {{ name }}_offsets = at::empty({0}, weights_offsets.options());
+    else {
+      TORCH_CHECK(false, "Invalid size of {{ name }}, expected 3 for CPU or {{ 5 if name == "weights" else 4 }} for CUDA but got ", {{ name }}.size());
    }
 {%- endmacro %}
 
+{%- macro get_optional_optim_tensor(name, suffix, idx) %}
+    auto {{ name }}_{{ suffix }} = GET_OPTIONAL_TENSOR_VALUE(optim_tensor[{{ idx }}], at::empty({0}, options));
+
+{%- endmacro %}
+
+/* This macro generates a code blob for unpacking a list of optional tensors
+   We cannot do a list of optional TensorLists, so we pack the optimizer's optional tensors in a flattened manner.
+   For readability and programmability, we pass all unified args (i.e., 5 items), as opposed to passing per device (like above),
+   which needs to be determined at runtime.
+*/
+{%- macro unpack_tensorlist_optional(name, arg_index) %}
+    at::TensorOptions options = weights_host.numel() > 0 ? weights_host.options() : weights_dev.options();
+    {{ get_optional_optim_tensor(name, "host", arg_index * 5) }}
+    {{ get_optional_optim_tensor(name, "dev", arg_index * 5 + 1) }}
+    {{ get_optional_optim_tensor(name, "uvm", arg_index * 5 + 2) }}
+    {{ get_optional_optim_tensor(name, "placements", arg_index * 5 + 3) }}
+    {{ get_optional_optim_tensor(name, "offsets", arg_index * 5 + 4) }}
+{%- endmacro %}
+
+////////////////////////////////////////////////////////////////////////////////
+// MACROS
+////////////////////////////////////////////////////////////////////////////////
+#define GET_OPTIONAL_TENSOR_VALUE(name, empty_tensor) name.has_value() ? name.value() : empty_tensor;
 
 ////////////////////////////////////////////////////////////////////////////////
 // Autograd Function Declarations
 ////////////////////////////////////////////////////////////////////////////////
@@ -552,10 +576,16 @@ class {{ autograd_func }} :
  static constexpr bool is_traceable = true;
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
+      {%- if not dense %}
      const Tensor& placeholder_autograd_tensor,
+      {%- endif %}
      const int64_t output_dtype,
+      {%- if dense %}
+      const Tensor& dev_weights,
+      const Tensor& weights_offsets,
+      {%- else %}
      const at::TensorList weights,
-      const Tensor& lxu_cache_weights,
+      {%- endif %}
      {%- if not nobag %}
      const Tensor& D_offsets,
      const c10::SymInt total_D,
@@ -572,32 +602,21 @@ class {{ autograd_func }} :
      const std::optional& indice_weights,
      const std::optional& feature_requires_grad,
      {%- endif %}
-      const Tensor& lxu_cache_locations,
-      std::optional uvm_cache_stats,
-      {%- if optimizer != "none" %}
-      const bool gradient_clipping,
-      const double max_gradient,
-      const bool stochastic_rounding,
-      {%- endif %}
      {%- if vbe %}
+      {%- if dense %}
      const std::optional& B_offsets,
      const std::optional& vbe_output_offsets_feature_rank,
      const std::optional& vbe_B_offsets_rank_per_feature,
+      {%- endif %} {# /* if dense */ #}
      const c10::SymInt max_B,
      const c10::SymInt max_B_feature_rank,
      const c10::SymInt vbe_output_size,
-      {%- endif %}
-      const bool is_experimental,
-      const bool use_uniq_cache_locations_bwd,
-      const bool use_homogeneous_placements,
-      {%- if is_gwd %}
-      {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %}
-      const std::optional& prev_iter_dev,
-      {%- endif %}
-      {%- if "iter" not in args_pt2.split_function_arg_names %}
-      const int64_t iter,
-      {%- endif %}
-      const double gwd_lower_bound,
+      {%- endif %} {# /* if vbe */ #}
+      {%- if not dense %}
+      std::vector> aux_tensor,
+      std::vector aux_int,
+      std::vector aux_float,
+      c10::List aux_bool,
      {%- endif %}
      {%- if ssd %}
      const at::TensorList& ssd_tensors,
@@ -609,16 +628,19 @@ class {{ autograd_func }} :
    {%- for arg_name in args_pt2.unified_pt2.split_saved_tensorlist %}
    {{ unpack_tensorlist(arg_name) }}
    {%- endfor %}
+    {%- if "optim_tensor" in args_pt2.unified_pt2.split_function_arg_names %}
+    TORCH_CHECK(optim_tensor.size() % 5 == 0);
+    {%- endif %}
    {%- for arg_name in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
-    {{ unpack_tensorlist_optional(arg_name) }}
+    {{ unpack_tensorlist_optional(arg_name, loop.index0) }}
    {%- endfor %}
 
    const auto T = weights_offsets.sym_numel();
-    {%- if vbe %}
-    const auto B_offsets_ = B_offsets.value_or(Tensor());
-    const auto vbe_output_offsets_feature_rank_ = vbe_output_offsets_feature_rank.value_or(Tensor());
-    const auto vbe_B_offsets_rank_per_feature_ = vbe_B_offsets_rank_per_feature.value_or(Tensor());
+    {%- if vbe %}
+    const auto B_offsets_ = GET_OPTIONAL_TENSOR_VALUE(aux_tensor[IDX_B_OFFSETS], Tensor());
+    const auto vbe_output_offsets_feature_rank_ = GET_OPTIONAL_TENSOR_VALUE(aux_tensor[IDX_VBE_OUTPUT_OFFSETS_FEATURE_RANK], Tensor());
+    const auto vbe_B_offsets_rank_per_feature_ = GET_OPTIONAL_TENSOR_VALUE(aux_tensor[IDX_VBE_B_OFFSETS_RANK_PER_FEATURE], Tensor());
    const c10::SymInt max_B_ = max_B;
    {%- else %}
    const auto max_B_ = offsets.sym_size(0) / T;
@@ -648,11 +670,15 @@ class {{ autograd_func }} :
          "{{ fwd_mdesc }}_tbe_fwd" + op_annotation);
      ctx->saved_data["op_annotation"] = op_annotation;
    }
-
+    {%- if not dense %}
    // NOTE: The `local_uvm_cache_stats` variable held by the nn.Module has dtype int32_t
    // TODO: Hook up with frontend code
-    const auto uvm_cache_stats_ = uvm_cache_stats
-        .value_or(at::empty({0}, weights_uvm.options().dtype(at::kInt)));
+    at::TensorOptions uvm_options = weights_host.numel() > 0 ? weights_host.options() : weights_dev.options();
+    const auto uvm_cache_stats = GET_OPTIONAL_TENSOR_VALUE(aux_tensor[IDX_UVM_CACHE_STATS], at::empty({0}, uvm_options.dtype(at::kInt)));
+    TORCH_CHECK(aux_tensor[IDX_LXU_CACHE_LOCATIONS].has_value(), "lxu_cache_locations should have value.");
+    const auto lxu_cache_locations = aux_tensor[IDX_LXU_CACHE_LOCATIONS].value();
+    const auto is_experimental = aux_bool[IDX_IS_EXPERIMENTAL_TBE];
+    {%- endif %}
 
    // Default values for Dynamo tracing
    // SymInt does not support bitshifts operator
@@ -709,20 +735,29 @@ class {{ autograd_func }} :
    {%- endif %} // vbe
 
    {%- if is_gwd %}
-    const auto prev_iter_dev_ = prev_iter_dev.value_or(Tensor());
+    {%- if "prev_iter" in args_pt2.unified_pt2.split_function_arg_names %}
+    const auto prev_iter_dev_ = GET_OPTIONAL_TENSOR_VALUE(prev_iter_dev, Tensor());
+    {%- else %}
+    const auto prev_iter_dev_ = GET_OPTIONAL_TENSOR_VALUE(aux_tensor[IDX_PREV_ITER_DEV], Tensor());
+    {%- endif %}
    {%- endif %}
 
    {%- if not nobag %}
-    const auto indice_weights_value = indice_weights.value_or(Tensor());
+    const auto indice_weights_value = GET_OPTIONAL_TENSOR_VALUE(indice_weights, Tensor());
    {%- endif %}
 
    ctx->save_for_backward({
+        {%- if dense %}
+        dev_weights,
+        weights_offsets,
+        {%- else %}
        weights_host,
        weights_dev,
        weights_uvm,
-        lxu_cache_weights,
+        weights_lxu_cache,
        weights_placements,
        weights_offsets,
+        {%- endif %}
        {%- if not nobag %}
        D_offsets,
        {%- endif %}
@@ -733,7 +768,9 @@ class {{ autograd_func }} :
        indice_weights_value,
        feature_requires_grad.value_or(Tensor()),
        {%- endif %}
+        {%- if not dense %}
        lxu_cache_locations,
+        {%- endif %}
        {%- if vbe %}
        B_offsets_,
        vbe_row_output_offsets,
@@ -752,35 +789,46 @@ class {{ autograd_func }} :
 
    {%- if not nobag %}
    ctx->saved_data["max_D"] = max_D;
+    ctx->saved_data["mixed_D"] = static_cast(aux_bool[IDX_MIXED_D]);
    ctx->saved_data["pooling_mode"] = pooling_mode;
    {%- else %}
    ctx->saved_data["D"] = D;
    {%- endif %}
    ctx->saved_data["total_hash_size_bits"] = total_hash_size_bits;
 
-    {%- if optimizer != "none" %}
-    ctx->saved_data["gradient_clipping"] = gradient_clipping;
-    ctx->saved_data["max_gradient"] = max_gradient;
-    ctx->saved_data["stochastic_rounding"] = stochastic_rounding;
+    {%- if optimizer != "none" and not dense %}
+    ctx->saved_data["gradient_clipping"] = static_cast(aux_bool[IDX_GRADIENT_CLIPPING]);
+    ctx->saved_data["max_gradient"] = aux_float[IDX_MAX_GRADIENT];
+    ctx->saved_data["stochastic_rounding"] = static_cast(aux_bool[IDX_STOCHASTIC_ROUNDING]);
    {%- endif %} {#-/* if optimizer != "none" */#}
    ctx->saved_data["info_B_num_bits"] = info_B_num_bits;
    const auto info_B_mask_int64 = static_cast(info_B_mask);
    ctx->saved_data["info_B_mask"] = info_B_mask_int64;
-    ctx->saved_data["use_uniq_cache_locations_bwd"] = use_uniq_cache_locations_bwd;
-    ctx->saved_data["use_homogeneous_placements"] = use_homogeneous_placements;
-    {%- if is_gwd %}
-    {%- if "iter" not in args_pt2.split_function_arg_names %}
-    ctx->saved_data["iter"] = iter;
+    {%- if not dense %}
+    ctx->saved_data["use_uniq_cache_locations_bwd"] = static_cast(aux_bool[IDX_USE_UNIQ_CACHE_LOCATIONS_BWD]);
+    ctx->saved_data["use_homogeneous_placements"] = static_cast(aux_bool[IDX_USE_HOMOGENEOUS_PLACEMENTS]);
    {%- endif %}
+    const auto iter = aux_int[IDX_ITER];
+    ctx->saved_data["iter"] = iter;
+    {%- if is_gwd %}
+    const auto gwd_lower_bound = aux_float[IDX_GWD_LOWER_BOUND];
    ctx->saved_data["gwd_lower_bound"] = gwd_lower_bound;
    {%- endif %}
    {%- if not nobag %}
    ctx->saved_data["output_dtype"] = output_dtype;
    {%- endif %}
 
-    {%- for (var, _) in args_pt2.saved_data %}
+    {%- if not dense %}
+    // unpack optim args
+    {%- for (var, dict_val, _, type) in args_pt2.unified_pt2.split_saved_data %}
+    {%- if type == "bool" %}
+    bool {{ var }} = {{ dict_val }};
+    {%- elif type != "c10::SymInt" %}
+    auto {{ var }} = {{ dict_val }};
+    {%- endif %}
    ctx->saved_data["{{ var }}"] = {{ var }};
    {%- endfor %}
+    {%- endif %}
 
    {%- if optimizer == "none" %}
    // Flatten
@@ -827,12 +875,17 @@ static torch::autograd::variable_list backward(
    torch::autograd::variable_list grad_outputs) {
  const auto saved = ctx->get_saved_variables();
  auto savedItr = std::begin(saved);
+  {%- if dense %}
+  auto dev_weights = *savedItr++;
+  auto weights_offsets = *savedItr++;
+  {%- else %}
  auto weights_host = *savedItr++;
  auto weights_dev = *savedItr++;
  auto weights_uvm = *savedItr++;
-  auto lxu_cache_weights = *savedItr++;
+  auto weights_lxu_cache = *savedItr++;
  auto weights_placements = *savedItr++;
  auto weights_offsets = *savedItr++;
+  {%- endif %}
  {%- if not nobag %}
  auto D_offsets = *savedItr++;
  {%- endif %}
@@ -843,7 +896,9 @@ static torch::autograd::variable_list backward(
  auto indice_weights = *savedItr++;
  auto feature_requires_grad = *savedItr++;
  {%- endif %}
+  {%- if not dense %}
  auto lxu_cache_locations = *savedItr++;
+  {%- endif %}
  {%- if vbe %}
  auto B_offsets = *savedItr++;
  auto vbe_row_output_offsets = *savedItr++;
@@ -864,35 +919,39 @@ static torch::autograd::variable_list backward(
 
  {%- if not nobag %}
  auto max_D = ctx->saved_data["max_D"].toInt();
+  const auto mixed_D = ctx->saved_data["mixed_D"].toBool();
  auto pooling_mode = ctx->saved_data["pooling_mode"].toInt();
  {%- else %}
  auto D = ctx->saved_data["D"].toInt();
  {%- endif %}
  auto total_hash_size_bits = ctx->saved_data["total_hash_size_bits"].toInt();
-  {%- if optimizer != "none" %}
+  {%- if optimizer != "none" and not dense %}
  auto gradient_clipping = ctx->saved_data["gradient_clipping"].toBool();
  auto max_gradient = ctx->saved_data["max_gradient"].toDouble();
  auto stochastic_rounding = ctx->saved_data["stochastic_rounding"].toBool();
  {%- endif %} {#-/* if optimizer != "none" */#}
  [[maybe_unused]] const int32_t info_B_num_bits = ctx->saved_data["info_B_num_bits"].toInt();
  [[maybe_unused]] const int64_t info_B_mask_int64 = ctx->saved_data["info_B_mask"].toInt();
+  {%- if not dense %}
  const auto use_uniq_cache_locations_bwd = ctx->saved_data["use_uniq_cache_locations_bwd"].toBool();
  const auto use_homogeneous_placements = ctx->saved_data["use_homogeneous_placements"].toBool();
-  {%- if is_gwd %}
-  {%- if "iter" not in args_pt2.split_function_arg_names %}
+  {%- endif %}
+  {%- if is_gwd or "iter" in args_pt2.unified_pt2.split_unpacked_arg_names %}
  const auto iter = ctx->saved_data["iter"].toInt();
  {%- endif %}
+  {%- if is_gwd %}
  const auto gwd_lower_bound = ctx->saved_data["gwd_lower_bound"].toDouble();
  {%- endif %}
  {%- if not nobag %}
  auto output_dtype = ctx->saved_data["output_dtype"].toInt();
  {%- endif %}
-
-  {%- for (var, ivalue_cast) in args_pt2.saved_data %}
+  {%- if not dense %}
+  {%- for (var, _ , ivalue_cast, type) in args_pt2.unified_pt2.split_saved_data %}
  auto {{ var }} = ctx->saved_data["{{ var }}"].{{ ivalue_cast }}();
  {%- endfor %}
+  {%- endif %}
 
  const static bool is_annotate_trace_enabled = config::is_feature_enabled(
      config::FeatureGateName::TBE_ANNOTATE_KINETO_TRACE);
@@ -914,7 +973,7 @@ static torch::autograd::variable_list backward(
 #endif
 
  using torch::autograd::Variable;
-  {%- if optimizer != "none" %}
+  {%- if optimizer != "none" and not dense %}
  auto grad_output = gradient_clipping ? clamp(grad_outputs[0], -max_gradient, max_gradient) : grad_outputs[0];
  {%- else %}
  auto& grad_output = grad_outputs[0];
@@ -948,19 +1007,24 @@ static torch::autograd::variable_list backward(
          .findSchemaOrThrow("fbgemm::{{ grad_indice_weights_op }}", "")
          .typed& indice_weights,
    const std::optional& feature_requires_grad,
-    const Tensor& lxu_cache_locations,
-    {%- if optimizer != "none" %}
-    const bool gradient_clipping,
-    const double max_gradient,
-    const bool stochastic_rounding,
+    const int64_t output_dtype,
+    {%- if not dense %}
+    const std::vector>& aux_tensor,
+    const std::vector& aux_int,
+    const std::vector& aux_float,
+    c10::List aux_bool,
    {%- endif %}
    {{ args_pt2.unified_pt2.split_function_args | join(", ") }},
-    const int64_t output_dtype = static_cast(SparseType::FP32),
-    const std::optional& B_offsets = std::nullopt,
-    const std::optional& vbe_output_offsets_feature_rank = std::nullopt,
-    const std::optional& vbe_B_offsets_rank_per_feature = std::nullopt,
    const c10::SymInt max_B = -1,
    const c10::SymInt max_B_feature_rank = -1,
-    const c10::SymInt vbe_output_size = -1,
-    const bool is_experimental_tbe = false, // formerly named is_experimental
-    const bool use_uniq_cache_locations_bwd = false,
-    const bool use_homogeneous_placements = false,
-    const std::optional& uvm_cache_stats = std::nullopt,
-    {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %}
-    const std::optional& prev_iter_dev = std::nullopt,
-    {%- endif %}
-    {%- if "iter" not in args_pt2.split_function_arg_names %}
-    const int64_t iter = 0,
-    {%- endif %}
-    const bool apply_global_weight_decay = false,
    {%- if ssd %}
-    const std::optional& ssd_tensors = std::nullopt,
+    const c10::SymInt vbe_output_size = -1,
+    const std::optional& ssd_tensors = std::nullopt
+    {%- else %}
+    const c10::SymInt vbe_output_size = -1
    {%- endif %}
-    const double gwd_lower_bound = 0
 ) {
+
  {%- if has_gpu_support or has_cpu_support %}
 
  {%- if not dense %}
@@ -1097,7 +1153,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2(
  static auto is_tbev2_enabled = config::is_feature_enabled(config::FeatureGateName::TBE_V2);
 
  // Set to experimental if either the feature is enabled in JK, or the user specifies to use TBEv2
-  const auto is_experimental = is_tbev2_enabled || is_experimental_tbe;
+  aux_bool[IDX_IS_EXPERIMENTAL_TBE] = is_tbev2_enabled || aux_bool[IDX_IS_EXPERIMENTAL_TBE];
  {%- endif %}
 
  {%- if ssd %}
@@ -1108,10 +1164,12 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2(
  {%- if has_vbe_support %}
  // has vbe support and on gpu
-  if (B_offsets.has_value()) {
+  if (aux_tensor[IDX_B_OFFSETS].has_value()) {
    {%- if has_global_weight_decay_support and not ssd %}
    // vbe and has gwd support
-    if (apply_global_weight_decay && weight_decay > 0) {
+    // if weight_decay arg is not passed or < 0 even though apply_global_weight_decay is True, we don't do gwd
+    // TODO: add check to ensure weight decay exists
+    if (aux_bool[IDX_APPLY_GLOBAL_WEIGHT_DECAY] && optim_float[{{args_pt2.unified_pt2.split_args_dict["optim_float"].index("weight_decay")}}] > 0) {
      {{ call_autograd(nobag=False, vbe=True, is_gwd=True) }}
    }
    {%- endif %} {#-/* if has_global_weight_decay_support */ #}
@@ -1122,7 +1180,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2(
  {%- if has_global_weight_decay_support and not ssd %}
  // has gwd support
-  if (apply_global_weight_decay && weight_decay > 0) {
+  if (aux_bool[IDX_APPLY_GLOBAL_WEIGHT_DECAY] && optim_float[{{args_pt2.unified_pt2.split_args_dict["optim_float"].index("weight_decay")}}] > 0) {
    // not vbe and gwd
    {{ call_autograd(nobag=False, vbe=False, is_gwd=True) }}
  }
@@ -1146,11 +1204,15 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2(
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
-  {%- set op_name = "{}_embedding_codegen_lookup_{}_function_pt2".format(bwd_mdesc, optimizer) %}
+  {%- set op_name = "{}_embedding_codegen_lookup_{}_function_pt2".format(fwd_mdesc, optimizer) %}
  m.def("{{ op_name }}("
+    {%- if dense %}
+    "    Tensor dev_weights, "
+    "    Tensor weights_offsets, "
+    {%- else %}
    "    Tensor placeholder_autograd_tensor, "
    "    Tensor[] weights, "
-    "    Tensor lxu_cache_weights, "
+    {%- endif %}
    "    Tensor D_offsets, "
    "    SymInt total_D, "
    "    SymInt max_D, "
@@ -1161,35 +1223,22 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
    "    int pooling_mode, "
    "    Tensor? indice_weights, "
    "    Tensor? feature_requires_grad, "
-    "    Tensor lxu_cache_locations, "
-    {%- if optimizer != "none" %}
-    "    bool gradient_clipping, "
-    "    float max_gradient, "
-    "    bool stochastic_rounding, "
-    {%- endif %}
+    "    int output_dtype, "
+    {%- if not dense %}
+    "    Tensor?[] aux_tensor, "
+    "    int[] aux_int, "
+    "    float[] aux_float, "
+    "    bool[] aux_bool, "
    "    {{ args_pt2.unified_pt2.split_function_schemas | join(", ") }}, "
-    "    int output_dtype=0, "
-    "    Tensor? B_offsets=None, "
-    "    Tensor? vbe_output_offsets_feature_rank=None, "
-    "    Tensor? vbe_B_offsets_rank_per_feature=None, "
    "    SymInt max_B=-1, "
    "    SymInt max_B_feature_rank=-1, "
+    {%- if ssd %}
    "    SymInt vbe_output_size=-1, "
-    "    bool is_experimental_tbe=False, "
-    "    bool use_uniq_cache_locations_bwd=False, "
-    "    bool use_homogeneous_placements=False, "
-    "    Tensor? uvm_cache_stats=None,"
-    {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %}
-    "    Tensor? prev_iter_dev=None, "
-    {%- endif %}
-    {%- if "iter" not in args_pt2.split_function_arg_names %}
-    "    int iter=0, "
+    "    Tensor[]? ssd_tensors=None"
+    {%- else %}
+    "    SymInt vbe_output_size=-1 "
    {%- endif %}
-    "    bool apply_global_weight_decay=False, "
-    {%- if ssd %}
-    "    Tensor[]? ssd_tensors=None,"
    {%- endif %}
-    "    float gwd_lower_bound=0 "
    ") -> Tensor",
    {PT2_COMPLIANT_TAG});
 
  // We're playing a funny trick here: we're using the autograd
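Callers of the registered op (the generated Python lookup invokers) are expected to build these aux lists in exactly the pt2_arg_utils.h order before dispatching. A minimal packing sketch under that assumption — illustrative helpers, not part of this patch; the index comments follow the aux_args dict from generate_backward_split.py:

```cpp
#include <vector>
#include <ATen/core/List.h>

// Packs aux_bool in the documented order; the autograd function reads it back
// through IDX_* (e.g. aux_bool[IDX_GRADIENT_CLIPPING]).
c10::List<bool> pack_aux_bool(
    bool is_experimental_tbe,           // 0: IDX_IS_EXPERIMENTAL_TBE
    bool use_uniq_cache_locations_bwd,  // 1: IDX_USE_UNIQ_CACHE_LOCATIONS_BWD
    bool use_homogeneous_placements,    // 2: IDX_USE_HOMOGENEOUS_PLACEMENTS
    bool apply_global_weight_decay,     // 3: IDX_APPLY_GLOBAL_WEIGHT_DECAY
    bool gradient_clipping,             // 4: IDX_GRADIENT_CLIPPING
    bool stochastic_rounding,           // 5: IDX_STOCHASTIC_ROUNDING
    bool mixed_D) {                     // 6: IDX_MIXED_D
  c10::List<bool> aux_bool;
  aux_bool.push_back(is_experimental_tbe);
  aux_bool.push_back(use_uniq_cache_locations_bwd);
  aux_bool.push_back(use_homogeneous_placements);
  aux_bool.push_back(apply_global_weight_decay);
  aux_bool.push_back(gradient_clipping);
  aux_bool.push_back(stochastic_rounding);
  aux_bool.push_back(mixed_D);
  return aux_bool;
}

// aux_float follows the same convention: {gwd_lower_bound, max_gradient}.
std::vector<double> pack_aux_float(double gwd_lower_bound, double max_gradient) {
  return {gwd_lower_bound, max_gradient};
}
```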
diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp
index 7c405d4f9..807fc94d5 100644
--- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp
+++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp
@@ -175,6 +175,7 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
    const Tensor& weights_offsets,
    const Tensor& D_offsets,
    const int64_t max_D,
+    const bool mixed_D,
    const Tensor& hash_size_cumsum,
    const int64_t total_hash_size_bits,
    const Tensor& indices,
@@ -207,19 +208,19 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
        torch::Dispatcher::singleton()
            .findSchemaOrThrow("fbgemm::{{ backward_op }}", "")
            .typed