Skip to content

Commit

Permalink
Unifying TBE API using List (Backend) (#3563)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#649


As the number of arguments in TBE keeps growing, some of the optimizers run into number of arguments limitation (i.e., 64) during pytorch operation registration. 

**For long-term growth and maintenance, we hence redesign TBE API by packing some of the arguments into list. Note that not all arguments are packed.**

We pack the arguments as a list for each type.
For **common** arguments, we pack 
- weights and arguments of type `Momentum` into TensorList
- other tensors and optional tensors to list of optional tensors `aux_tensor`
- `int` arguments into `aux_int`
- `float` arguments into `aux_float`
- `bool` arguments into `aux_bool`.

Similarly for **optimizer-specific** arguments, we pack
- arguments of type `Momentum` that are *__not__ optional* into TensorList
- *optional* tensors to list of optional tensors `optim_tensor`
- `int` arguments into `optim_int`
- `float` arguments into `optim_float`
- `bool` arguments into `optim_bool`.

We see issues with pytorch registration across packing SymInt in python-C++, so we unroll and pass SymInt arguments individually. 

**This significantly reduces number of arguments.** For example, `split_embedding_codegen_lookup_rowwise_adagrad_with_counter_function`, which currently has 61 arguments only have 26 arguments with this API design. 

Please refer to the design doc on which arguments are packed and signature.
Design doc:
https://docs.google.com/document/d/1dCBg7dcf7Yq9FHVrvXsAmFtBxkDi9o6u0r-Ptd4UDPE/edit?tab=t.0#heading=h.6bip5pwqq8xb

Full signature for each optimizer lookup function will be provided shortly.

Reviewed By: sryap

Differential Revision: D68054868
  • Loading branch information
spcyppt authored and facebook-github-bot committed Jan 16, 2025
1 parent 5c76d93 commit e7408fa
Show file tree
Hide file tree
Showing 11 changed files with 631 additions and 295 deletions.
3 changes: 3 additions & 0 deletions fbgemm_gpu/cmake/tbe_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@
+ [
"gen_embedding_backward_split_common_device_kernel.cuh",
]
+ [
"pt2_arg_utils.h",
]
)

gen_defused_optim_templates = [
Expand Down
55 changes: 54 additions & 1 deletion fbgemm_gpu/codegen/genscript/generate_backward_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@ def generate_backward_split_gpu(**kwargs: Any) -> None:
)
for filename in [
f"lookup_{optimizer}{sdesc}.py",
f"lookup_{optimizer}{sdesc}_pt2.py",
]:
template.write(
filename, is_fbcode=args.is_fbcode, ssd=ssd, **kwargs
Expand Down Expand Up @@ -331,6 +330,23 @@ def generate_rocm_backward_split(**kwargs: Any) -> None:
},
)

@staticmethod
def generate_backward_header(
aux_args: Dict[str, List[str]], aux_names: List[str]
) -> None:
"""
Generate a header file that contains enum of argument order from the dict
Parameters:
aux_args (Dict[str, List[str]]): a dict containing a list of arguments
aux_names (List[str]): names of the argument types (e.g. aux_tensor, aux_int, etc.)
Return:
None
"""
# Generate backward header for PT2 Autograd
template = CodeTemplate.load("training/pt2/pt2_arg_utils_template.h")
template.write(f"pt2_arg_utils.h", aux_args=aux_args, aux_names=aux_names)

@staticmethod
def generate_python_sources(
all_optimizers: List[str], ssd_optimizers: List[str]
Expand Down Expand Up @@ -375,6 +391,40 @@ def generate() -> None:
"actions_count",
]

aux_names = ["aux_tensor", "aux_int", "aux_float", "aux_bool"]
# This is a dict of auxilary arguments used in TBE PT2 interface where the aux
# arguments of a type are packed into a list for that type. This dict maintains the
# order of the arguments of each type.
aux_args: Dict[str, List[str]] = {
"aux_tensor": [
"B_offsets", # 0
"vbe_output_offsets_feature_rank", # 1
"vbe_B_offsets_rank_per_feature", # 2
"lxu_cache_locations", # 3
"uvm_cache_stats", # 4
"prev_iter_dev", # 5
],
"aux_int": [
"iter", # 0
],
"aux_float": [
"gwd_lower_bound", # 0
"max_gradient", # 1
],
"aux_bool": [
"is_experimental_tbe", # 0
"use_uniq_cache_locations_bwd", # 1
"use_homogeneous_placements", # 2
"apply_global_weight_decay", # 3
"gradient_clipping", # 4
"stochastic_rounding", # 5
"mixed_D", # 6
],
}
assert (
list(aux_args.keys()) == aux_names
), f"{aux_names} must match {aux_args.keys()}"

all_optimizers = []
ssd_optimizers = []

Expand All @@ -399,6 +449,9 @@ def generate() -> None:
BackwardSplitGenerator.generate_backward_grad()
BackwardSplitGenerator.generate_backward_indices()

# Generate headers for backwards
BackwardSplitGenerator.generate_backward_header(aux_args, aux_names)

BackwardSplitGenerator.generate_python_sources(all_optimizers, ssd_optimizers)


Expand Down
191 changes: 175 additions & 16 deletions fbgemm_gpu/codegen/genscript/optimizer_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

# pyre-strict
# pyre-ignore-all-errors[29]
# pyre-ignore-all-errors[53]
# flake8: noqa F401


Expand Down Expand Up @@ -205,6 +206,40 @@ def schema_bool_arg(name: str, default: bool = False) -> str:
return f"bool {name} = {default}"


def list_arg(ty: str) -> str:
"""
Returns a C++ argument for a list of optimizer arguments the given type.
Parameters:
ty (str) - type of the list e.g., "int", "float", "tensor"
Returns:
C++ arguemnt for a list of the given type e.g., for a list of int returns "std::vector<int> optim_int",
"""
return {
"tensor": "std::vector<std::optional<at::Tensor>> optim_tensor",
"int": "std::vector<int64_t> optim_int",
"float": "std::vector<double> optim_float",
"bool": "c10::List<bool> optim_bool",
}[ty]


def schema_list_arg(ty: str) -> str:
"""
Returns a C++ schema for a list of optimizer arguments the given type.
Parameters:
ty (str) - type of the list e.g., "int", "float", "tensor"
Returns:
C++ arguemnt for a list of the given type e.g., for a list of int returns "int[] optim_int",
"""
return {
"tensor": "Tensor?[] optim_tensor",
"int": "int[] optim_int",
"float": "float[] optim_float",
"bool": "bool[] optim_bool",
}[ty]


def optional_tensor_arg(name: str) -> str:
return f"std::optional<Tensor> {name} = std::nullopt"

Expand All @@ -230,7 +265,6 @@ def schema_optional_tensorlist_arg(name: str) -> str:


def make_kernel_arg(
# pyre-fixme[11]: Annotation `ArgType` is not defined as a type.
ty: ArgType,
name: str,
default: Union[int, float, None],
Expand Down Expand Up @@ -505,6 +539,10 @@ class PT2ArgsSet:
split_function_schemas: List[str]
split_saved_tensorlist: List[str]
split_saved_tensorlist_optional: List[str]
split_saved_data: List[dict[str, str]]
split_variables: List[str]
split_unpacked_arg_names: List[str]
split_args_dict: Dict[str, List[str]]

@staticmethod
# pyre-ignore[3]
Expand All @@ -525,59 +563,178 @@ def create(
Returns:
PT2ArgsSet object with the following attributes:
split_function_args: List[str] - List of function arguments used in unified lookup and autograd functions
Tensors will be packed and pass as TensorList
e.g., ['at::TensorList momentum1', 'double eps', 'double weight_decay'].
Tensors will be packed and pass as TensorList. Auxillary arguments will be packed in dict.
e.g., ['at::TensorList momentum1', 'at::Dict<std:string, int> optim_int'].
split_function_arg_names: List[str] - List of argument names used in unified lookup and autograd functions
e.g., ['momentum1', 'eps', 'weight_decay'].
e.g., ['momentum1', 'optim_int', 'optim_float'].
split_function_schemas: List[str] - List of arguments used in unified lookup and autograd functions in the schema format
e.g., ['Tensor[] momentum1', 'float eps', 'float weight_decay'].
split_saved_tensorlist: List[str] - List of tensor names that are packed into tensorlist and will be unpacked in
PT2 autograd function. e.g., ['momentum1'].
split_saved_tensorlist_optional: List[str] - List of tensor names that are packed into tensorlist but are optional
and will be unpacked in PT2 autograd function e.g., ['row_counter'].
split_saved_data: List[dict[str, str]] - List of non-tensor arguments that are saved for backward
split_unpacked_arg_names: List[str] - List of argument names, unrolled from list
e.g., ['momentum1', 'eps', 'weight_decay', 'iter'].
split_args_dict: Dict[str, List[str]] - Dict of optim arguments' types containing the argument names of that type.
e.g., if an optimizer only has an int argument called iter, the dict will look like:
{'optim_tensor': [], 'optim_int': ['iter'], 'optim_float': [], 'optim_bool': []}
"""
split_function_arg_names = []
split_function_args = []
split_function_schemas = []
split_saved_tensorlist = []
split_saved_tensorlist_optional = []
split_saved_data = []
split_variables = []
split_unpacked_arg_names = []
has_optim_tensor = False # optim tensors here are optional tensor
has_optim_int = False
has_optim_float = False
has_optim_bool = False
split_args_dict = {
"optim_tensor": [],
"optim_int": [],
"optim_float": [],
"optim_bool": [],
}
# list of symint args to be appended after optim_xxx args
# since they have default values
symint_list: List[OptimItem] = []

for s in arg_spec:
if s.name == "learning_rate_tensor":
split_function_arg_names.append(s.name)
split_unpacked_arg_names.append(s.name)
split_function_args.append(tensor_arg(s.name))
split_function_schemas.append(tensor_arg(s.name))
split_variables.append(f"ret.push_back(Variable()); // {s.name}")
elif s.ty in (
ArgType.TENSOR,
ArgType.INT_TENSOR,
ArgType.LONG_TENSOR,
ArgType.PLACEHOLDER_TENSOR,
):
name = s.name
split_function_arg_names.append(name)
split_unpacked_arg_names.append(name)
if s.is_optional:
split_function_args.append(optional_tensorlist_arg(name))
split_function_schemas.append(schema_optional_tensorlist_arg(name))
split_saved_tensorlist_optional.append(name)
split_args_dict["optim_tensor"].append(s.name)
has_optim_tensor = True
else:
split_function_args.append(
tensor_list_arg_no_default(name, pass_by_ref=False)
)
split_function_arg_names.append(name)
split_function_schemas.append(
schema_tensor_list_arg_no_default(name)
)
split_saved_tensorlist.append(name)
split_variables.append(
f"ret.push_back(Variable()); // {s.name}_dev or host"
)
split_variables.append(
f"ret.push_back(Variable()); // {s.name}_placements"
)
split_variables.append(
f"ret.push_back(Variable()); // {s.name}_offsets"
)
split_variables.append("if (" + name + "_host.numel() == 0) {")
split_variables.append(
f"ret.push_back(Variable()); // {s.name}_uvm"
)
split_variables.append("}")
else:
split_function_arg_names.append(s.name)
split_function_args.append(make_function_arg(s.ty, s.name, s.default))
split_function_schemas.append(
make_function_schema_arg(s.ty, s.name, s.default)
)
if s.ty == ArgType.INT:
# iter is passed in aux_int
if s.name != "iter":
split_args_dict["optim_int"].append(s.name)
split_saved_data.append(
(
s.name,
f'optim_int[{len(split_args_dict["optim_int"]) - 1}]',
make_ivalue_cast(s.ty),
"int64_t",
)
)
has_optim_int = True
elif s.ty == ArgType.SYM_INT:
symint_list.append(s)
split_saved_data.append(
(
s.name,
"",
make_ivalue_cast(s.ty),
"c10::SymInt",
)
)
elif s.ty == ArgType.FLOAT:
split_args_dict["optim_float"].append(s.name)
split_saved_data.append(
(
s.name,
f'optim_float[{len(split_args_dict["optim_float"])- 1}]',
make_ivalue_cast(s.ty),
"double",
)
)
has_optim_float = True
elif s.ty == ArgType.BOOL:
split_args_dict["optim_bool"].append(s.name)
split_saved_data.append(
(
s.name,
f'optim_bool[{len(split_args_dict["optim_bool"])- 1}]',
make_ivalue_cast(s.ty),
"bool",
)
)
has_optim_bool = True
else:
raise ValueError(f"Unsupported type {s.ty}")
split_unpacked_arg_names.append(s.name)

def append_lists(type_name: str) -> None:
"""
Append the list as one argument to the list of function arguments, schemas, names and saved_variables.
e.g., if type_name is "tensor", optim_tensor will be appended with the corresponding syntax.
Parameters:
type_name (str) - type name of the list to be appended
Returns:
None
"""
split_function_args.append(list_arg(type_name))
split_function_schemas.append(schema_list_arg(type_name))
split_function_arg_names.append(f"optim_{type_name}")
split_variables.append(f"ret.push_back(Variable()); // optim_{type_name}")

if has_optim_tensor:
append_lists("tensor")
if has_optim_int:
append_lists("int")
if has_optim_float:
append_lists("float")
if has_optim_bool:
append_lists("bool")
for s in symint_list:
split_function_arg_names.append(s.name)
split_function_args.append(make_function_arg(s.ty, s.name, s.default))
split_function_schemas.append(
make_function_schema_arg(s.ty, s.name, s.default)
)
split_variables.append(f"ret.push_back(Variable()); // {s.name}")
return PT2ArgsSet(
split_function_args=split_function_args,
split_function_arg_names=split_function_arg_names,
split_function_schemas=split_function_schemas,
split_saved_tensorlist=split_saved_tensorlist,
split_saved_tensorlist_optional=split_saved_tensorlist_optional,
split_saved_data=split_saved_data,
split_variables=split_variables,
split_unpacked_arg_names=split_unpacked_arg_names,
split_args_dict=split_args_dict,
)


Expand Down Expand Up @@ -637,12 +794,14 @@ def create(
if s.is_optional:
has_optional_tensors = True

# Optional tensors are converted to tensor in autograd functions
# Reorganize arguments for wrapper, backend and kernel functions
# Optim arg order: non-optional tensors, learning_rate_tensor, non-tensors, optional tensors
# The optional tensors are converted to Tensor in autograd functions
# Hence, need to reorganize such that the tensors come before non-tensors which have default values values
# This is used in wrapper, backend and kernel functions
if has_optional_tensors:
# Arg order: non-optional tensors, learning_rate_tensor, optional tensors as tensors, non-tensors,
# reordered args for split_arg_spec: non-optional tensors, learning_rate_tensor, optional tensors as tensors, non-tensors
split_arg_spec = reorder_args(split_arg_spec)
# Arg order: non-optional tensors, optional tensors as tensors, learning rate (float), non-tensors
# reordered args for kernel_split_arg_spec: non-optional tensors, optional tensors as tensors, learning rate (float), non-tensors
kernel_split_arg_spec = reorder_args(kernel_split_arg_spec)

# Compute placeholder tensor combinations
Expand Down
Loading

0 comments on commit e7408fa

Please sign in to comment.