diff --git a/doc/design.rst b/doc/design.rst index 4a735a0..bb5d13b 100644 --- a/doc/design.rst +++ b/doc/design.rst @@ -74,7 +74,7 @@ engineer's behalf. The transformations are recorded on :math:`\rho_{ij}:\mathcal{T}[a_1]\times\mathcal{T}[a_2]\times\ldots\mapsto \mathcal{C}\left(\mathcal{T}[a_1]\times\mathcal{T}[a_2]\times\ldots\right)`, where, :math:`\mathcal{C}` is a mapping on the numeric data-types as provided -by :func:`numpy.find_common_type`. +by :func:`numpy.result_type`. The transformation writer prescribes the transformation space without making diff --git a/src/feinsum/canonicalization.py b/src/feinsum/canonicalization.py index 4bccf61..3dbf24f 100644 --- a/src/feinsum/canonicalization.py +++ b/src/feinsum/canonicalization.py @@ -322,7 +322,7 @@ def get_einsum_dag(einsum: FusedEinsum) -> Map[EinsumGraphNode, (frozenset(einsum.value_to_dtype[use] for use in uses) for uses in use_row), frozenset()) - out_dtype = np.find_common_type(list(use_dtypes), []) + out_dtype = np.result_type(*use_dtypes) einsum_dag[dtype_to_node[out_dtype]].add(array_to_node[output_name]) # }}} diff --git a/src/feinsum/codegen/loopy.py b/src/feinsum/codegen/loopy.py index a21c62d..39651be 100644 --- a/src/feinsum/codegen/loopy.py +++ b/src/feinsum/codegen/loopy.py @@ -175,9 +175,8 @@ def generate_loopy(einsum: FusedEinsum, for i_output in range(einsum.noutputs): arg_to_dtype: Dict[Argument, np.dtype[Any]] = { - EinsumOperand(ioperand): np.find_common_type({value_to_dtype[use] - for use in uses}, - []) + EinsumOperand(ioperand): np.result_type(*{value_to_dtype[use] + for use in uses}) for ioperand, uses in enumerate(einsum .use_matrix[i_output])} @@ -185,7 +184,7 @@ def generate_loopy(einsum: FusedEinsum, zip(result_name_in_lpy_knl[i_output], schedule.result_names, schedule.arguments)): - dtype = np.find_common_type({arg_to_dtype[arg] for arg in args}, []) + dtype = np.result_type(*{arg_to_dtype[arg] for arg in args}) value_to_dtype = value_to_dtype.set(name_in_lpy_knl, dtype) arg_to_dtype[IntermediateResult(name_in_feinsum)] = dtype diff --git a/src/feinsum/tuning/impls/cogent.py b/src/feinsum/tuning/impls/cogent.py index 75a9b9e..109baed 100644 --- a/src/feinsum/tuning/impls/cogent.py +++ b/src/feinsum/tuning/impls/cogent.py @@ -201,10 +201,9 @@ def transform(t_unit: lp.TranslationUnit, # Verify register file usage # -------------------------- - if (rx or 1) * (ry or 1) * np.find_common_type( - [ensm.value_to_dtype[ensm_A], - ensm.value_to_dtype[ensm_B]], - []).itemsize > REG_FILE_SPACE_PER_WI: + if (rx or 1) * (ry or 1) * np.result_type( + ensm.value_to_dtype[ensm_A], + ensm.value_to_dtype[ensm_B]).itemsize > REG_FILE_SPACE_PER_WI: raise fnsm.InvalidParameterError("Exceeds register file limits") assert tx <= 32 and ty <= 32