diff --git a/iree/turbine/kernel/compiler/vector_codegen.py b/iree/turbine/kernel/compiler/vector_codegen.py index ae3d21ab0..6becdab94 100644 --- a/iree/turbine/kernel/compiler/vector_codegen.py +++ b/iree/turbine/kernel/compiler/vector_codegen.py @@ -854,9 +854,10 @@ def cast_kernel_buffer( return value, MemRefType(ir_type), py_type -def cast_vector( - emitter: ThreadEmitter, value, *, element_type: Optional[IrType] = None -): +def cast_vector(emitter: ThreadEmitter, + value, + *, + element_type: Optional[IrType] = None): proxy_value = cast_py_value(emitter, value) # Cast scalar types correctly first. @@ -864,6 +865,7 @@ def cast_vector( # Implicit scalar type promotion. proxy_value = ScalarBuilder.to_dtype(proxy_value, element_type) + print(f"proxy_value {proxy_value}") value = proxy_value.ir_value # After scalar promotion, promote to vector. diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py index a95f33a2f..a09c29909 100644 --- a/iree/turbine/kernel/wave/codegen.py +++ b/iree/turbine/kernel/wave/codegen.py @@ -1069,7 +1069,8 @@ def handle_set_symbol(emitter: WaveEmitter, node: fx.Node): raise ValidationError("Malformed arguments") from e register = cast_vector(emitter, register, element_type=IndexType.get()) - emitter.dynamic_dims[symbol] = _to_scalar(register) + # emitter.dynamic_dims[symbol] = _to_scalar(register) + emitter.dynamic_dims[symbol] = register ###############################################################################