diff --git a/iree/turbine/kernel/compiler/vector_codegen.py b/iree/turbine/kernel/compiler/vector_codegen.py
index 6becdab94..5586c8507 100644
--- a/iree/turbine/kernel/compiler/vector_codegen.py
+++ b/iree/turbine/kernel/compiler/vector_codegen.py
@@ -854,10 +854,9 @@ def cast_kernel_buffer(
     return value, MemRefType(ir_type), py_type


-def cast_vector(emitter: ThreadEmitter,
-                value,
-                *,
-                element_type: Optional[IrType] = None):
+def cast_vector(
+    emitter: ThreadEmitter, value, *, element_type: Optional[IrType] = None
+):
     proxy_value = cast_py_value(emitter, value)

     # Cast scalar types correctly first.
@@ -865,7 +864,7 @@ def cast_vector(emitter: ThreadEmitter,
         # Implicit scalar type promotion.
         proxy_value = ScalarBuilder.to_dtype(proxy_value, element_type)

-    print(f"proxy_value {proxy_value}")
+    # print(f"proxy_value {proxy_value}")
     value = proxy_value.ir_value

     # After scalar promotion, promote to vector.
diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py
index a09c29909..c8297f1ab 100644
--- a/iree/turbine/kernel/wave/codegen.py
+++ b/iree/turbine/kernel/wave/codegen.py
@@ -721,6 +721,7 @@ def _build_mask(

 def _construct_gather_scatter_indices(
     emitter: WaveEmitter,
+    # TODO: fix the pre-existing `symbolc_shape` typo (should be `symbolic_shape`).
     symbolc_shape: tuple[IndexExpr],
     index: tuple[IndexExpr],
     mapping: IndexMapping,
@@ -728,6 +729,7 @@ def _construct_gather_scatter_indices(
     is_read: bool,
     dynamic_vals: tuple[Any, ...],
     is_contiguous: bool,
+    vector_shaped_symbols=frozenset(),
 ) -> tuple[OpResult, OpResult, OpResult]:
     # Apply symbolc_shape order to indices, e.g. if original mapping is
     # {M: iter(0), N: iter(1)} and symbolc_shape is (N, M), result will
@@ -795,34 +797,46 @@ def extract0(src):
     offsets = []
     strides = strides_from_symbolic_shape(idxc, symbolc_shape, allow_mixed_shapes=True)
     start_indices_offset = _compute_offset(start_indices, strides)
-    for i in range(elements_per_thread):
-        # Update fastest dim, i.e. in case of identity mapping it will
-        # be equivalent to just vector.load
-        subs = [(sym, idx) for sym, idx in zip(iters.keys(), start_indices_orig)]
-        subs[fastest_dim] = (subs[fastest_dim][0], start_indices_orig[fastest_dim] + i)
-        indices = [i.subs(subs) for i in index_mapping]
-        # First, we build indices as if resulting gather/scatter `start_indices`
-        # are 0 as mapping expression may depend on absolute value of index
-        # (e.g. `index % 32`). Then we adjust for the non-0 `start_indices` by
-        # subtracting computed previously linear `start_indices_offset`. For
-        # simple cases like transpose, the resulting expression should fold into
-        # simple constant while more complex expressions may requires actual
-        # arith ops on dynamic values.
-        offset = _compute_offset(indices, strides) - start_indices_offset
-        offset = subs_idxc(offset)
-
-        if offset.is_number:
-            # If resulted offset sympy expr is convertible to int constant it
-            # will be directly encoded into `arith.constant`.
-            # For non-constant expressions, we will generate a real sequence of
-            # arith ops and then `vector.insertelement` them into offsets vec.
-            offset = int(offset)
-        else:
-            need_dynamic_offsets = True
-            break
+    # TODO: we don't necessarily care whether the symbols are vector-shaped,
+    # but whether they are indexed by the fastest-varying dimension.
+    # Note that we may want to "expand" the symbol to per-element copies and
+    # trigger `need_dynamic_offsets` below.
+    if len(start_indices_offset.free_symbols.intersection(vector_shaped_symbols)) != 0:
+        need_dynamic_offsets = True
+
+    if not need_dynamic_offsets:
+        for i in range(elements_per_thread):
+            # Update fastest dim, i.e. in case of identity mapping it will
+            # be equivalent to just vector.load
+            subs = [(sym, idx) for sym, idx in zip(iters.keys(), start_indices_orig)]
+            subs[fastest_dim] = (
+                subs[fastest_dim][0],
+                start_indices_orig[fastest_dim] + i,
+            )
+            indices = [i.subs(subs) for i in index_mapping]
+
+            # First, we build indices as if resulting gather/scatter `start_indices`
+            # are 0 as mapping expression may depend on absolute value of index
+            # (e.g. `index % 32`). Then we adjust for the non-0 `start_indices` by
+            # subtracting computed previously linear `start_indices_offset`. For
+            # simple cases like transpose, the resulting expression should fold into
+            # simple constant while more complex expressions may requires actual
+            # arith ops on dynamic values.
+            offset = _compute_offset(indices, strides) - start_indices_offset
+            offset = subs_idxc(offset)
+
+            if offset.is_number:
+                # If resulted offset sympy expr is convertible to int constant it
+                # will be directly encoded into `arith.constant`.
+                # For non-constant expressions, we will generate a real sequence of
+                # arith ops and then `vector.insertelement` them into offsets vec.
+                offset = int(offset)
+            else:
+                need_dynamic_offsets = True
+                break

-        offsets.append(offset)
+            offsets.append(offset)

 offsets_vec_type = VectorType.get([elements_per_thread], IndexType.get())
     if need_dynamic_offsets:
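For reference, the behaviour the hunk above relies on can be reproduced in isolation: when the linear start offset contains no vector-shaped symbol, the per-element offsets fold to plain integer constants; when it does, the new guard flips `need_dynamic_offsets`. A minimal sympy sketch — symbols, strides, and the offset helper below are made up for illustration; the real code goes through `_compute_offset`, `subs_idxc`, and IR-level values rather than plain sympy sets:

```python
import sympy

n, m, OFFSET = sympy.symbols("n m OFFSET", integer=True)
vector_shaped_symbols = {OFFSET}   # stand-in for the symbols backed by per-lane vectors
row_stride = 64                    # assumed row-major stride of the source buffer

def linear_offset(i, j):
    # flat offset of element (i, j) in the source buffer
    return row_stride * i + j

# Identity/transpose-like case: no vector-shaped symbol in the start offset,
# so each per-element offset folds to a constant usable in `arith.constant`.
start = linear_offset(n, m)
print([sympy.expand(linear_offset(n, m + k) - start) for k in range(4)])
# -> [0, 1, 2, 3]

# Indirect case: the offset depends on OFFSET, so the guard trips and the
# constant-folding loop is skipped in favor of dynamic offsets.
start = linear_offset(n, OFFSET)
print(bool(start.free_symbols & vector_shaped_symbols))  # True
```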
@@ -834,13 +848,22 @@ def extract0(src):
         subs = [(sym, idx) for sym, idx in zip(iters.keys(), start_indices_orig)]
         # Last item in `subs` corresponds to last item in `start_indices_orig`
-        # which is fastest changing dim.
-        # Replacing last element with `idxc.iota(elements_per_thread)` will
-        # generate vectorized index code, each element in it corresponding to
-        # individual vector element index.
+        # which is fastest changing dim. Replacing last element with
+        # `idxc.iota(elements_per_thread)` will generate vectorized index code,
+        # each element in it corresponding to individual vector element index.
+        #
+        # TODO: a vector-shaped symbol means we can't simply use `iota` here;
+        # we should take the per-element values of the symbol instead, i.e.
+        # somehow feed the different values of OFFSET (or of other
+        # vector-shaped symbols) into the `get_sympy_index` call below.
+        #
+        # We also need to handle several symbols at once, especially a mix of
+        # constant and non-constant symbols, and decide what happens when
+        # there are no symbols at all.
         subs[-1] = (
             subs[-1][0],
-            start_indices_orig[-1] + idxc.iota(elements_per_thread),
+            start_indices_orig[-1],  # + idxc.iota(elements_per_thread),
         )
         dynamic_vals_map = {
             sym: val
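The TODO in this hunk is about how the vectorized index gets formed. Roughly speaking, the current path gives the fastest-varying dimension a `base + iota` ramp, while a vector-shaped OFFSET would need its own per-lane values instead. A plain-Python illustration of that distinction — the values are invented and no MLIR is involved, this is only the shape of the problem:

```python
# Illustrative only: how per-lane indices are formed today vs. what the TODO
# above asks for when OFFSET is itself per-lane.
elements_per_thread = 4
base = 32                                   # start index of the fastest dim

# Current path: base + iota -> a contiguous per-lane ramp.
iota = list(range(elements_per_thread))
per_lane = [base + i for i in iota]
print(per_lane)                             # [32, 33, 34, 35]

# With a vector-shaped OFFSET, the per-lane indices should come from the
# symbol's own lanes rather than from an iota ramp.
offset_lanes = [3, 0, 7, 1]                 # hypothetical per-lane OFFSET values
per_lane = [base + off for off in offset_lanes]
print(per_lane)                             # [35, 32, 39, 33]
```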
@@ -905,6 +928,15 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
     dyn_vals = tuple(
         cast_vector(emitter, reg, element_type=IndexType.get()) for reg in dyn_vals
     )
+
+    # TODO: this can be sunk down closer to where it is used.
+    vector_shaped_symbols = set(
+        sym
+        for sym, value in emitter.dynamic_dims.items()
+        if isinstance(value.type, ShapedType)
+        and math.prod(ShapedType(value.type).shape) != 1
+    )
+
     start_indices, offsets_vec, mask = _construct_gather_scatter_indices(
         emitter=emitter,
         symbolc_shape=input_shape,
@@ -914,6 +946,7 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
         is_read=True,
         dynamic_vals=dyn_vals,
         is_contiguous=get_custom(node).is_contiguous_vec(),
+        vector_shaped_symbols=vector_shaped_symbols,
     )

     zero = get_constant_attr(0, element_type)
@@ -1326,8 +1359,7 @@ def handle_and_op(lhs: Value, rhs: Value) -> OpResult:
     if _is_integer_like_type(element_type):
         result = arith_d.andi(lhs, rhs)
     else:
-        raise ValidationError(
-            f"Found unhandled operand type for le: {element_type}")
+        raise ValidationError(f"Found unhandled operand type for and_op: {element_type}")
     return result

@@ -1362,6 +1394,7 @@ def handle_minimum(lhs: Value, rhs: Value) -> OpResult:
     )
     return result

+
 ###############################################################################
 # Unary math Ops
 ###############################################################################
@@ -1750,7 +1783,7 @@ def handle_cast(emitter: WaveEmitter, node: fx.Node):
     dst_vector_type = VectorType.get(src_vector_type.shape, dst_elem_type)

     if src_vector_type == dst_vector_type:
-        emitter.bind_node_proxy(node, vector_src)
+        emitter.bind_node_proxy(node, IRProxyValue(vector_src))
         return

     is_src_float = _is_float_type(src_elem_type)
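`handle_read` now classifies dynamic dims by the shape of their IR values: anything whose element count is not 1 counts as vector shaped. A rough standalone equivalent of that classification, with plain name/shape pairs standing in for `emitter.dynamic_dims` and `ShapedType`:

```python
import math

# Stand-ins for emitter.dynamic_dims: symbol name -> shape of its IR value.
dynamic_dims = {"OFFSET": (8,), "ZERO": (1,)}

vector_shaped_symbols = {
    sym for sym, shape in dynamic_dims.items() if math.prod(shape) != 1
}
print(vector_shaped_symbols)    # {'OFFSET'}: a per-lane value, unlike the scalar-like ZERO
```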
diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py
index 661208770..1fceffe1a 100644
--- a/iree/turbine/kernel/wave/utils.py
+++ b/iree/turbine/kernel/wave/utils.py
@@ -1420,6 +1420,7 @@ def check_is_mapping_contiguous(
         return True

     # TODO: Better dyn vals analysis.
+    # TODO: we also need to check whether there are additional symbols in the mapping.
     if mapping.num_dynamic_vals != 0:
         return False

@@ -1437,6 +1438,15 @@ def check_is_mapping_contiguous(
     index_mapping = tuple(subs_idxc(i) for i in index_mapping)
     iters = mapping.iters

+    # TODO: if the symbols present in the index are not known to be scalars or
+    # themselves contiguous, we shouldn't report the read as contiguous.
+    #
+    # TODO: thread through the fact that _some_ symbols have a vector shape.
+    for expr in index_mapping:
+        if len(expr.free_symbols - mapping.iters.keys()) != 0:
+            # print("### avoided")  # debug trace, disabled
+            return False
+
     subs = [(sym, sym + int(i == len(iters) - 1)) for i, sym in enumerate(iters)]
     diff = [
         approximate_difference(
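The new bail-out in `check_is_mapping_contiguous` treats any mapping expression that mentions a non-iterator symbol as potentially non-contiguous, forcing the gather path. A small sympy sketch of that test — the symbols are illustrative, not the actual `IndexMapping` iterators:

```python
import sympy

i, OFFSET = sympy.symbols("i OFFSET", integer=True)
iters = {i}                          # mapping iterators

identity_mapping = (i,)              # plain contiguous read
indirect_mapping = (OFFSET,)         # rpe read indexed through a set symbol

def has_extra_symbols(mapping):
    # Any free symbol beyond the iterators means contiguity can't be assumed.
    return any(expr.free_symbols - iters for expr in mapping)

print(has_extra_symbols(identity_mapping))   # False -> may stay contiguous
print(has_extra_symbols(indirect_mapping))   # True  -> force the gather path
```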
diff --git a/playground/attention_with_rpe_template.py b/playground/attention_with_rpe_template.py
index 0c11a111f..a4c78c6aa 100644
--- a/playground/attention_with_rpe_template.py
+++ b/playground/attention_with_rpe_template.py
@@ -18,9 +18,12 @@
 from iree.turbine.kernel.wave.templates.attention_common import AttentionShape


-def get_vanilla_attention_kernel(shape: AttentionShape, mfma_variant: MMAType,
-                                 dynamic_dims: bool,
-                                 max_context_length: Optional[int]):
+def get_vanilla_attention_kernel(
+    shape: AttentionShape,
+    mfma_variant: MMAType,
+    dynamic_dims: bool,
+    max_context_length: Optional[int],
+):
     # RPE
     ZERO = tkl.sym.ZERO
     OFFSET = tkl.sym.OFFSET
@@ -44,9 +47,7 @@ def get_vanilla_attention_kernel(shape: AttentionShape, mfma_variant: MMAType,
     STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

     # Expose user-constraints
-    constraints: list[tkw.Constraint] = [
-        tkw.WorkgroupConstraint(M, BLOCK_M, 0)
-    ]
+    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
     constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
     constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)]
     constraints += [tkw.TilingConstraint(K2, BLOCK_K2)]
@@ -65,11 +66,7 @@ def get_vanilla_attention_kernel(shape: AttentionShape, mfma_variant: MMAType,
             threads_per_wave=64,
             waves_per_block=(4, 1, 1),
             mma_type=mfma_variant[1],
-            vector_shapes={
-                B: 0,
-                M: Mvec,
-                N: Nvec
-            },
+            vector_shapes={B: 0, M: Mvec, N: Nvec},
         )
     ]

@@ -79,17 +76,9 @@ def get_vanilla_attention_kernel(shape: AttentionShape, mfma_variant: MMAType,
     i = tkw.IndexMapping.iterator(0)
     j = tkw.IndexMapping.iterator(1)
     k = tkw.IndexMapping.iterator(2)
-    mapping = tkw.IndexMapping(num_iterators=3,
-                               inputs={
-                                   B: i,
-                                   N: j,
-                                   M: k
-                               },
-                               outputs={
-                                   B: i,
-                                   M: k,
-                                   N: j
-                               })
+    mapping = tkw.IndexMapping(
+        num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j}
+    )

     offset_mapping = tkw.IndexMapping(
         num_iterators=1,
@@ -99,9 +88,11 @@ def get_vanilla_attention_kernel(shape: AttentionShape, mfma_variant: MMAType,

     use_t5_rpe = max_context_length is not None
     if use_t5_rpe:
-        rpe_layout = tkl.MemoryLayout(shape=[
-            max_context_length,
-        ])
+        rpe_layout = tkl.MemoryLayout(
+            shape=[
+                max_context_length,
+            ]
+        )
     assert use_t5_rpe, "use_t5_rpe needed until rpe arg can DCE without crashing"

     @tkw.wave(constraints)
@@ -138,16 +129,18 @@ def repeat(
             # When fusing into the FA variant, adding locally before the max and
             # the partial softmax should be equivalent.
             if use_t5_rpe:
-                ZERO = tkl.Register[M, K2, tkl.i64](0)
-                MAX = tkl.Register[M, K2, tkl.i64](max_context_length)
+                ZERO = tkl.Register[B, M, K2, tkl.i64](0)
+                MAX = tkl.Register[B, M, K2, tkl.i64](max_context_length)
                 # 1. Indices i and j broadcasted along K2 with a twist:
                 # here we use *static* information that is *implicitly* encoded
                 # in the *transformation*: under the distribution constraints
                 # specified we know that the shape [M] will eventually resolve
                 # to [1] and can thus be "cast + broadcast" to [K2].
                 i = tkw.self_index(M, tkl.i64, elements_per_thread=1)
-                i = tkw.broadcast(i, target_shape=[M, K2])
-                j = tkw.self_index(K2, tkl.i64, elements_per_thread=1)
+                i = tkw.broadcast(i, target_shape=[B, M, K2])
+                j = tkw.self_index(
+                    K2, tkl.i64, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK
+                )

                 # 2. Clip i - j to the proper bucket in [0, max_context_length]
                 # to represent the following:
@@ -163,18 +156,18 @@ def repeat(

                 # idx = tkw.minimum(idx, MAX)
                 # select variant.
-                idx = tkw.select(tkw.and_op(i - j >= ZERO, i - j <= MAX),
-                                 i - j, ZERO)
+                idx = tkw.select(tkw.and_op(i - j >= ZERO, i - j <= MAX), i - j, ZERO)

                 # 3. Read indirect into the 1-D rpe array via offset_mapping.
                 tkw.set_symbol(OFFSET, idx)  # offset will have shape [M, K2]
                 rpe_reg = tkw.read(
                     rpe,
                     mapping=offset_mapping,
-                    elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)
+                    elements_per_thread=LOAD_ELEMS_PER_THREAD_QK,
+                )

                 # 4. Tadaaaa.
-                x_j = x_j + rpe_reg + tkw.cast(ZERO * idx, tkl.i64)
+                x_j = x_j + rpe_reg + tkw.cast(ZERO * idx, tkl.f32)

             m_j = tkw.max(x_j, partial_max, dim=K2)
             e_delta_max = tkw.exp2(partial_max - m_j)
@@ -191,19 +184,13 @@ def repeat(
         res_max, res_sum, res_mm = repeat
         reciprocal_sum = tkw.reciprocal(res_sum)
         res = res_mm * reciprocal_sum
-        tkw.write(res,
-                  c,
-                  mapping=mapping,
-                  elements_per_thread=STORE_ELEMS_PER_THREAD)
+        tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD)

     hyperparams = {
         ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
-        LOAD_ELEMS_PER_THREAD_QK:
-        get_mfma_load_elems_per_thread(mfma_variant[0]),
-        LOAD_ELEMS_PER_THREAD_PV:
-        get_mfma_load_elems_per_thread(mfma_variant[1]),
-        STORE_ELEMS_PER_THREAD:
-        get_mfma_store_elems_per_thread(mfma_variant[1]),
+        LOAD_ELEMS_PER_THREAD_QK: get_mfma_load_elems_per_thread(mfma_variant[0]),
+        LOAD_ELEMS_PER_THREAD_PV: get_mfma_load_elems_per_thread(mfma_variant[1]),
+        STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant[1]),
         BLOCK_B: 1,
         BLOCK_M: 128,
         BLOCK_N: 64,
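Step 2 of the kernel clips `i - j` into `[0, max_context_length]` and collapses everything out of range onto bucket 0; the test harness (next diff) pins the boundary entries of `rpe` to zero so those clamped buckets contribute nothing. A torch sketch of the same bucketing, mirroring the `t5_rpe_masked_cond` reference in the test file below — sizes are made up:

```python
import torch

max_context_length = 10
seq_len = 6
positions = torch.arange(seq_len)
d = positions.unsqueeze(1) - positions.unsqueeze(0)   # i - j for every (i, j) pair

# select(and(d >= 0, d <= MAX), d, 0): negative / too-large distances all land
# on bucket 0, whose rpe value the harness sets to 0.
idx = torch.where((d >= 0) & (d <= max_context_length), d, torch.zeros_like(d))
print(idx)
```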
diff --git a/playground/attention_with_rpe_test.py b/playground/attention_with_rpe_test.py
index b0d161ef6..f9fc3ebef 100644
--- a/playground/attention_with_rpe_test.py
+++ b/playground/attention_with_rpe_test.py
@@ -15,7 +15,8 @@
 from iree.turbine.kernel.wave.constraints import MMAType
 from iree.turbine.kernel.wave.templates.attention_common import AttentionShape
 from iree.turbine.kernel.wave.templates.vanilla_attention import (
-    get_vanilla_attention_kernel as get_vanilla_tkw_attention_kernel)
+    get_vanilla_attention_kernel as get_vanilla_tkw_attention_kernel,
+)
 from iree.turbine.kernel.wave.utils import (
     device_randn,
     device_zeros,
@@ -23,8 +24,8 @@
     to_default_device,
 )
 from attention_with_rpe_template import (
-    get_vanilla_attention_kernel as
-    get_vanilla_tkw_attention_with_rpe_output_kernel)
+    get_vanilla_attention_kernel as get_vanilla_tkw_attention_with_rpe_output_kernel,
+)

 torch.manual_seed(0)
 torch.set_printoptions(
@@ -37,23 +38,26 @@
 ### TKW Harness
 def run(fun: Callable, hparams, *args) -> Any:
     with torch.profiler.profile(
-            activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
+        activities=[torch.profiler.ProfilerActivity.CUDA]
+    ) as prof:
         with torch.no_grad():  # Disable gradient calculations
             with TestLaunchContext(
-                    hparams,
-                    canonicalize=True,
-                    compile_config={"print-ir-after": "all"},
-                    run=True,
-                    run_config=get_default_run_config(),
-                    run_bench=False,
-                    schedule=False,
-                    use_scheduling_barriers=False,
+                hparams,
+                canonicalize=True,
+                compile_config={"print-ir-after": "all"},
+                run=True,
+                run_config=get_default_run_config(),
+                run_bench=False,
+                schedule=False,
+                use_scheduling_barriers=False,
             ):
                 fun(*args)

     print(
         prof.key_averages(group_by_input_shape=True).table(
-            sort_by="self_cuda_time_total", row_limit=10))
+            sort_by="self_cuda_time_total", row_limit=10
+        )
+    )


 #################################################################################
@@ -64,8 +68,9 @@ def run(fun: Callable, hparams, *args) -> Any:
 shape.query_seq_len = 128
 shape.kv_seq_len = 128

-assert shape.num_query_heads == shape.num_kv_heads, \
-    "expected query and kv to have the same number of heads!"
+assert (
+    shape.num_query_heads == shape.num_kv_heads
+), "expected query and kv to have the same number of heads!"

 q_shape = (shape.num_query_heads, shape.query_seq_len, shape.head_size)
 k_shape = (shape.num_kv_heads, shape.kv_seq_len, shape.head_size)
@@ -90,40 +95,44 @@ def run(fun: Callable, hparams, *args) -> Any:
 # Applied pre-softmax on the MMA'ed result so f32.
 # Provision more room for clipping and adding 0 at the boundaries.
 rpe = device_zeros(1000 + max_context_length + 2, dtype=torch.float32)
-rpe = rpe[:max_context_length + 2].view(max_context_length + 2)
+rpe = rpe[: max_context_length + 2].view(max_context_length + 2)
 rpe.copy_(device_randn(max_context_length + 2, dtype=torch.float32))
 rpe[0] = 0
 rpe[max_context_length + 1] = 0


-def t5_rpe_masked_cond(rpe, max_context_length: int, sequence_length: int,
-                       dtype):
+def t5_rpe_masked_cond(rpe, max_context_length: int, sequence_length: int, dtype):
     positions = to_default_device(torch.arange(sequence_length))
     pos_diff = positions.unsqueeze(1) - positions.unsqueeze(0)
-    mask = to_default_device((pos_diff >= 0)
-                             & (pos_diff <= max_context_length))
+    mask = to_default_device((pos_diff >= 0) & (pos_diff <= max_context_length))
     rpe_cond = device_zeros(sequence_length, sequence_length, dtype=dtype)
     rpe_cond[mask] = rpe[pos_diff[mask]]
     return rpe_cond


 # rpe_cond is used by torch only
-rpe_cond = t5_rpe_masked_cond(rpe,
-                              max_context_length=max_context_length,
-                              sequence_length=shape.kv_seq_len,
-                              dtype=tkw_attention_with_rpe_output.dtype)
+rpe_cond = t5_rpe_masked_cond(
+    rpe,
+    max_context_length=max_context_length,
+    sequence_length=shape.kv_seq_len,
+    dtype=tkw_attention_with_rpe_output.dtype,
+)

 #################################################################################
 # TKW BASE ATTENTION
 #################################################################################
 ### RPE version
-tkw_attention_with_rpe, hyperparams, dynamic_symbols, dynamic_symbols_map = \
-    get_vanilla_tkw_attention_with_rpe_output_kernel(
-        shape,
-        mfma_variant=[MMAType.F32_16x16x16_F16,
-                      MMAType.F32_16x16x16_F16],
-        dynamic_dims=False,
-        max_context_length = max_context_length + 2)
+(
+    tkw_attention_with_rpe,
+    hyperparams,
+    dynamic_symbols,
+    dynamic_symbols_map,
+) = get_vanilla_tkw_attention_with_rpe_output_kernel(
+    shape,
+    mfma_variant=[MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16],
+    dynamic_dims=False,
+    max_context_length=max_context_length + 2,
+)


 def attention_with_rpe(tq, tk, tv, trpe, toutput):
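The run calls in the next hunk pre-scale both the query scores and (now) `rpe` by `log2e`; that is what lets the kernel's `tkw.exp2`-based softmax reproduce the base-e reference. A small torch check of the identity being relied on — a minimal sketch, independent of the harness:

```python
import torch

log2e = 1.44269504089
x = torch.randn(8)

# exp(x) == 2**(x * log2(e)), so an exp2-based kernel matches the e-based
# softmax as long as its additive inputs are pre-scaled by log2(e).
assert torch.allclose(torch.exp(x), torch.exp2(x * log2e))

# The same holds for the additive RPE term: softmax(a + rpe) in base e equals
# the normalized exp2 of (a + rpe) * log2e, hence both the q @ k^T scores and
# rpe are multiplied by log2e before launch.
a = torch.randn(4, 4)
rpe = torch.randn(4)
ref = torch.softmax(a + rpe, dim=-1)
exp2_softmax = torch.exp2((a + rpe) * log2e)
exp2_softmax = exp2_softmax / exp2_softmax.sum(dim=-1, keepdim=True)
assert torch.allclose(ref, exp2_softmax, atol=1e-6)
```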
@@ -131,24 +140,41 @@ def attention_with_rpe(tq, tk, tv, trpe, toutput):
     print(mb.module_op)


-run(attention_with_rpe, hyperparams, q * dk_sqrt * log2e, k,
-    v.permute([0, 2, 1]), rpe, tkw_attention_with_rpe_output)
+run(
+    attention_with_rpe,
+    hyperparams,
+    q * dk_sqrt * log2e,
+    k,
+    v.permute([0, 2, 1]),
+    rpe * log2e,
+    tkw_attention_with_rpe_output,
+)

 ### Reference version
-tkw_attention, hyperparams, dynamic_symbols, dynamic_symbols_map = \
-    get_vanilla_tkw_attention_kernel(
-        shape,
-        mfma_variant=[MMAType.F32_16x16x16_F16,
-                      MMAType.F32_16x16x16_F16],
-        dynamic_dims=False)
+(
+    tkw_attention,
+    hyperparams,
+    dynamic_symbols,
+    dynamic_symbols_map,
+) = get_vanilla_tkw_attention_kernel(
+    shape,
+    mfma_variant=[MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16],
+    dynamic_dims=False,
+)


 def attention(tq, tk, tv, toutput):
     tkw_attention(tq, tk, tv, toutput)


-run(attention, hyperparams, q * dk_sqrt * log2e, k, v.permute([0, 2, 1]),
-    tkw_attention_output)
+run(
+    attention,
+    hyperparams,
+    q * dk_sqrt * log2e,
+    k,
+    v.permute([0, 2, 1]),
+    tkw_attention_output,
+)

 tkw_rpe_delta_output = tkw_attention_with_rpe_output - tkw_attention_output
 # print(tkw_rpe_delta_output)
@@ -157,7 +183,8 @@ def attention(tq, tk, tv, toutput):
 #################################################################################
 # TORCH ATTENTION and ATTENTION + RPE
 #################################################################################
 torch_attention_ref_output = torch.nn.functional.scaled_dot_product_attention(
-    q, k, v, attn_mask=None)
+    q, k, v, attn_mask=None
+)
 a = torch.matmul(q, k.transpose(-1, -2)) * dk_sqrt
 torch_attention_output = torch.matmul(torch.softmax(a, dim=-1), v)
@@ -165,23 +192,24 @@ def attention(tq, tk, tv, toutput):
 # Sanity check that torch_attention_output and torch_attention_ref_output are
 # the same so we can inject RPE pre-softmax and compute the delta.
 # We will test that the delta post-softmax is the same for torch and TKW.
-assert_close(torch_attention_output,
-             torch_attention_ref_output,
-             atol=2e-3,
-             rtol=2e-3)
+assert_close(torch_attention_output, torch_attention_ref_output, atol=2e-3, rtol=2e-3)

 a += rpe_cond.unsqueeze(0)
 torch_attention_with_rpe_output = torch.matmul(F.softmax(a, dim=-1), v)
 torch_rpe_delta_output = torch_attention_with_rpe_output - torch_attention_output

 # Check basic attentions match as we expect.
-assert_close(torch_attention_output.to(dtype=tkw_attention_output.dtype),
-             tkw_attention_output,
-             atol=2e-3,
-             rtol=2e-3)
+assert_close(
+    torch_attention_output.to(dtype=tkw_attention_output.dtype),
+    tkw_attention_output,
+    atol=2e-3,
+    rtol=2e-3,
+)

 # Check RPE attentions match as we expect.
-assert_close(torch_rpe_delta_output.to(dtype=tkw_rpe_delta_output.dtype),
-             tkw_rpe_delta_output,
-             atol=2e-3,
-             rtol=2e-3)
+assert_close(
+    torch_rpe_delta_output.to(dtype=tkw_rpe_delta_output.dtype),
+    tkw_rpe_delta_output,
+    atol=2e-3,
+    rtol=2e-3,
+)