Skip to content

Commit

Permalink
replace calls to find_common_type with result_type
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikcfd committed Jul 30, 2023
1 parent 65d20f4 commit fb92a30
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 10 deletions.
2 changes: 1 addition & 1 deletion doc/design.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ engineer's behalf. The transformations are recorded on
:math:`\rho_{ij}:\mathcal{T}[a_1]\times\mathcal{T}[a_2]\times\ldots\mapsto
\mathcal{C}\left(\mathcal{T}[a_1]\times\mathcal{T}[a_2]\times\ldots\right)`,
where, :math:`\mathcal{C}` is a mapping on the numeric data-types as provided
by :func:`numpy.find_common_type`.
by :func:`numpy.result_type`.


The transformation writer prescribes the transformation space without making
Expand Down
2 changes: 1 addition & 1 deletion src/feinsum/canonicalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def get_einsum_dag(einsum: FusedEinsum) -> Map[EinsumGraphNode,
(frozenset(einsum.value_to_dtype[use] for use in uses)
for uses in use_row),
frozenset())
out_dtype = np.find_common_type(list(use_dtypes), [])
out_dtype = np.result_type(*use_dtypes)
einsum_dag[dtype_to_node[out_dtype]].add(array_to_node[output_name])

# }}}
Expand Down
7 changes: 3 additions & 4 deletions src/feinsum/codegen/loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,17 +175,16 @@ def generate_loopy(einsum: FusedEinsum,

for i_output in range(einsum.noutputs):
arg_to_dtype: Dict[Argument, np.dtype[Any]] = {
EinsumOperand(ioperand): np.find_common_type({value_to_dtype[use]
for use in uses},
[])
EinsumOperand(ioperand): np.result_type(*{value_to_dtype[use]
for use in uses})
for ioperand, uses in enumerate(einsum
.use_matrix[i_output])}

for name_in_lpy_knl, name_in_feinsum, args in (
zip(result_name_in_lpy_knl[i_output],
schedule.result_names,
schedule.arguments)):
dtype = np.find_common_type({arg_to_dtype[arg] for arg in args}, [])
dtype = np.result_type(*{arg_to_dtype[arg] for arg in args})
value_to_dtype = value_to_dtype.set(name_in_lpy_knl, dtype)
arg_to_dtype[IntermediateResult(name_in_feinsum)] = dtype

Expand Down
7 changes: 3 additions & 4 deletions src/feinsum/tuning/impls/cogent.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,9 @@ def transform(t_unit: lp.TranslationUnit,

# Verify register file usage
# --------------------------
if (rx or 1) * (ry or 1) * np.find_common_type(
[ensm.value_to_dtype[ensm_A],
ensm.value_to_dtype[ensm_B]],
[]).itemsize > REG_FILE_SPACE_PER_WI:
if (rx or 1) * (ry or 1) * np.result_type(
ensm.value_to_dtype[ensm_A],
ensm.value_to_dtype[ensm_B]).itemsize > REG_FILE_SPACE_PER_WI:
raise fnsm.InvalidParameterError("Exceeds register file limits")

assert tx <= 32 and ty <= 32
Expand Down

0 comments on commit fb92a30

Please sign in to comment.