From b13238894de711cd81b634b565cf995fd7812c79 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Fri, 25 Apr 2025 13:02:09 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- kernels/portable/cpu/util/dtype_util.h | 28 +++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/kernels/portable/cpu/util/dtype_util.h b/kernels/portable/cpu/util/dtype_util.h index 2286ca50be..cf332b3541 100644 --- a/kernels/portable/cpu/util/dtype_util.h +++ b/kernels/portable/cpu/util/dtype_util.h @@ -228,7 +228,7 @@ enum class SupportedTensorDtypes { namespace internal { template -load_to_compute_fn get_load_to_compute_fn( +load_to_compute_fn get_load_to_compute_fn_impl( const Tensor& t, SupportedTensorDtypes dtypes) { switch (dtypes) { @@ -251,6 +251,10 @@ load_to_compute_fn get_load_to_compute_fn( return nullptr; } +// NOTE: applying the #ifdef EXECUTORCH_SELECTIVE_BUILD_DTYPE +// technique used for get_load_to_compute_fn in this path was a size +// regression rather than an improvement. Haven't fully investigated +// why; just be aware when trying to improve size further. template store_compute_to_tensor_fn get_store_compute_to_tensor_fn( const Tensor& t, @@ -285,6 +289,28 @@ store_compute_to_tensor_fn get_store_compute_to_tensor_fn( return nullptr; } +#ifndef EXECUTORCH_SELECTIVE_BUILD_DTYPE +inline constexpr const char kGenericElementwiseOpName[] = "generic_elementwise_op"; +#endif // EXECUTORCH_SELECTIVE_BUILD_DTYPE + +template +load_to_compute_fn get_load_to_compute_fn( + const Tensor& t, + SupportedTensorDtypes dtypes) { + // NOTE: Selective build relies on the operator name being passed + // here. When it's *not* active, using the same operator name + // everywhere saves on size because we don't require a new template + // instantiation for every operator. + return get_load_to_compute_fn_impl< + CTYPE_COMPUTE, +#ifdef EXECUTORCH_SELECTIVE_BUILD_DTYPE + op_name +#else // EXECUTORCH_SELECTIVE_BUILD_DTYPE + kGenericElementwiseOpName +#endif // EXECUTORCH_SELECTIVE_BUILD_DTYPE + >(t, dtypes); +} + bool check_tensor_dtype( const Tensor t, SupportedTensorDtypes dtypes,