Skip to content

Commit

Permalink
Debug OFFSET does not accept vector
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolasvasilache authored and ftynse committed Feb 7, 2025
1 parent f2af934 commit f4ec891
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
8 changes: 5 additions & 3 deletions iree/turbine/kernel/compiler/vector_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,16 +854,18 @@ def cast_kernel_buffer(
return value, MemRefType(ir_type), py_type


def cast_vector(
emitter: ThreadEmitter, value, *, element_type: Optional[IrType] = None
):
def cast_vector(emitter: ThreadEmitter,
value,
*,
element_type: Optional[IrType] = None):
proxy_value = cast_py_value(emitter, value)

# Cast scalar types correctly first.
if element_type and not ShapedType.isinstance(proxy_value.ir_value.type):
# Implicit scalar type promotion.
proxy_value = ScalarBuilder.to_dtype(proxy_value, element_type)

print(f"proxy_value {proxy_value}")
value = proxy_value.ir_value

# After scalar promotion, promote to vector.
Expand Down
3 changes: 2 additions & 1 deletion iree/turbine/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,7 +1069,8 @@ def handle_set_symbol(emitter: WaveEmitter, node: fx.Node):
raise ValidationError("Malformed arguments") from e

register = cast_vector(emitter, register, element_type=IndexType.get())
emitter.dynamic_dims[symbol] = _to_scalar(register)
# emitter.dynamic_dims[symbol] = _to_scalar(register)
emitter.dynamic_dims[symbol] = register


###############################################################################
Expand Down

0 comments on commit f4ec891

Please sign in to comment.