RPR on top of causal attention #453

Draft
wants to merge 21 commits into base: main
Changes from all commits (21 commits)
4 changes: 4 additions & 0 deletions iree/turbine/aot/support/ir_utils.py
@@ -489,6 +489,10 @@ def _is_float_type(type):
return isinstance(type, (BF16Type, F16Type, F32Type, F64Type, Float8E4M3FNUZType))


def _is_index_type(type):
return isinstance(type, IndexType)


def _is_integer_like_type(type):
return isinstance(type, (IntegerType, IndexType))

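For orientation (not part of the diff): the new predicate narrows the existing `_is_integer_like_type` check to the MLIR `index` case. A minimal hedged check, assuming the IREE compiler Python bindings are importable and that the private helper is pulled directly from the module touched above:

```python
from iree.compiler.ir import Context, IndexType, IntegerType

from iree.turbine.aot.support.ir_utils import _is_index_type

# MLIR types can only be constructed under an active context.
with Context():
    assert _is_index_type(IndexType.get())
    assert not _is_index_type(IntegerType.get_signless(32))
```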
1 change: 1 addition & 0 deletions iree/turbine/kernel/_support/dtype.py
@@ -72,6 +72,7 @@ def bitwidth(self):

bf16 = DataType("bf16")
bool = DataType("bool", "i1")
i1 = bool
i4 = DataType("i4")
i8 = DataType("i8")
i16 = DataType("i16")
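A quick hedged illustration of what the alias buys: `i1` is the same `DataType` object as the existing `bool` dtype (whose MLIR spelling is already "i1"), so the new comparison ops can name boolean results without shadowing Python's built-in `bool` at call sites:

```python
# Assumes only that the package is importable; both names come from the module above.
from iree.turbine.kernel._support.dtype import i1, bool as bool_dtype

assert i1 is bool_dtype  # an alias, not a new dtype
```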
8 changes: 5 additions & 3 deletions iree/turbine/kernel/compiler/vector_codegen.py
@@ -854,16 +854,18 @@ def cast_kernel_buffer(
return value, MemRefType(ir_type), py_type


def cast_vector(
emitter: ThreadEmitter, value, *, element_type: Optional[IrType] = None
):
def cast_vector(emitter: ThreadEmitter,
value,
*,
element_type: Optional[IrType] = None):
proxy_value = cast_py_value(emitter, value)

# Cast scalar types correctly first.
if element_type and not ShapedType.isinstance(proxy_value.ir_value.type):
# Implicit scalar type promotion.
proxy_value = ScalarBuilder.to_dtype(proxy_value, element_type)

print(f"proxy_value {proxy_value}")
value = proxy_value.ir_value

# After scalar promotion, promote to vector.
113 changes: 101 additions & 12 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -20,7 +20,7 @@
from ..lang.wave_types import Memory, Register, IndexMapping
from ..lang.global_symbols import *
from .._support.indexing import IndexExpr, IndexSymbol, IndexSequence
from .._support.dtype import DataType
from .._support.dtype import DataType, i1
from .._support.regions import RegionGraph
from .base import OpDispatcher
import numpy as np
@@ -45,6 +45,14 @@ def allocate(
...


def self_index(
idx: IndexExpr,
dtype: DataType,
elements_per_thread: Optional[IndexExpr | int] = None,
) -> "Register":
...


def extract(
register: "Register",
offsets: tuple[IndexExpr],
@@ -140,6 +148,10 @@ def minimum(lhs: "Register", rhs: "Register") -> "Register":
...


def and_op(lhs: "Register", rhs: "Register") -> "Register":
...


def broadcast(
arg: "Register", target_shape: Optional[Sequence[IndexExpr | int]] = None
) -> "Register":
@@ -178,6 +190,10 @@ def reshape(inputs: Sequence["Register"]) -> "Register":
...


def select(cond: "Register", if_true: "Register", if_false: "Register") -> "Register":
...


def define_op(op_name: str) -> Callable[[T], T]:
def decorator(cls: T) -> T:
cls.tkw_op_name = op_name
@@ -680,14 +696,8 @@ def transform_index(
return index


@define_py_op(operator.add)
@define_py_op(operator.sub)
@define_py_op(operator.mul)
@define_py_op(operator.truediv)
@define_interface_op("maximum")
@define_interface_op("minimum")
@dataclass
class BinaryPyOp(CustomOp, ABC):
class BinaryOpBase(CustomOp, ABC):
"""
Represents an elementwise binary python operator.

@@ -715,21 +725,48 @@ def indexing_dims(self) -> list[IndexSymbol]:
def py_operator(self) -> str:
return self.tkw_op_name

def infer_type(self):
def infer_shape(self) -> Any:
lhs_type = get_custom(self.lhs).type
rhs_type = get_custom(self.rhs).type
has_same_type = has_same_custom_type(lhs_type, rhs_type)
if has_same_type:
self.type = lhs_type
return
return lhs_type.symbolic_shape

lhs_dim_set = set(lhs_type.symbolic_shape)
rhs_dim_set = set(rhs_type.symbolic_shape)
if lhs_dim_set.isdisjoint(rhs_dim_set):
raise ValueError(
"BinaryPyOp requires lhs and rhs shape to be at least broadcastable."
f" got {lhs_type.symbolic_shape} vs {rhs_type.symbolic_shape}"
)

# TODO: this logic looks suspicious. Specifically, there's no check that
# rhs_dim_set subsumes lhs_dim_set, they may partially overlap.
broadcasted_type = lhs_type if lhs_dim_set > rhs_dim_set else rhs_type
self.type = broadcasted_type
return broadcasted_type.symbolic_shape


@define_py_op(operator.add)
@define_py_op(operator.sub)
@define_py_op(operator.mul)
@define_py_op(operator.truediv)
@define_interface_op("maximum")
@define_interface_op("minimum")
@define_interface_op("and_op")
@dataclass
class BinaryPyOp(BinaryOpBase, ABC):
def infer_type(self):
self.type = Register[(*self.infer_shape(), get_custom(self.lhs).type.dtype)]


@define_py_op(operator.gt)
@define_py_op(operator.ge)
@define_py_op(operator.lt)
@define_py_op(operator.le)
@dataclass
class ComparisonPyOp(BinaryOpBase, ABC):
def infer_type(self):
self.type = Register[(*self.infer_shape(), i1)]
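To make the refactored shape rule concrete: `infer_shape` returns the common shape when both operands agree, keeps the strictly larger symbolic shape when one operand's dims contain the other's, and raises on disjoint shapes. As the TODO above points out, partially overlapping shapes are not rejected; the `>` strict-superset test is False for them, so the rhs shape silently wins. A small self-contained sketch of that set logic, with plain strings standing in for `IndexSymbol`s:

```python
def pick_broadcast_shape(lhs_shape: tuple, rhs_shape: tuple) -> tuple:
    # Mirrors the set comparison in BinaryOpBase.infer_shape (illustration only).
    lhs, rhs = set(lhs_shape), set(rhs_shape)
    if lhs.isdisjoint(rhs):
        raise ValueError("lhs and rhs must be at least broadcastable")
    return lhs_shape if lhs > rhs else rhs_shape

assert pick_broadcast_shape(("M", "N"), ("N",)) == ("M", "N")      # strict superset: lhs wins
assert pick_broadcast_shape(("M", "K"), ("K", "N")) == ("K", "N")  # partial overlap: rhs wins, the TODO's concern
```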


@define_interface_op("log2")
@@ -759,6 +796,42 @@ def infer_type(self):
self.type = src_type


@define_op("select")
@dataclass
class SelectOp(CustomOp):
cond: fx.Node
if_true: fx.Node
if_false: fx.Node

@property
def indexing_dims(self) -> list[IndexSymbol]:
combined_dims = []
combined_dims += get_custom(self.cond).indexing_dims
combined_dims += get_custom(self.if_true).indexing_dims
combined_dims += get_custom(self.if_false).indexing_dims
return list(dict.fromkeys(combined_dims))

def infer_type(self):
cond_type = get_custom(self.cond).type
if_true_type = get_custom(self.if_true).type
if_false_type = get_custom(self.if_false).type

if cond_type.dtype != i1:
raise ValueError("SelectOp expects condition type to be i1.")

if if_true_type.dtype != if_false_type.dtype:
raise ValueError("SelectOp expects lhs and rhs dtype to match.")

# TODO: support broadcasting behavior.
if (
cond_type.symbolic_shape != if_true_type.symbolic_shape
or cond_type.symbolic_shape != if_false_type.symbolic_shape
):
raise ValueError("SelectOp doesn't support broadcasting. (yet?)")

self.type = if_true_type


@final
@dataclass
class Unknown(CustomOp):
@@ -940,6 +1013,22 @@ def type(self) -> "Memory":
return Memory[(*self.shape, self.address_space, self.dtype)]


@define_op("self_index")
@dataclass
class SelfIndex(CustomOp):
idx: IndexExpr
dtype: DataType
elements_per_thread: Optional[IndexExpr | int]

@property
def indexing_dims(self) -> list[IndexSymbol]:
return [self.idx]

@property
def type(self) -> "Register":
return Register[(self.idx, self.dtype)]


@define_op("shared_memory_barrier")
@dataclass
class SharedMemoryBarrier(CustomOp):
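Putting the new ops together (context, not code from this diff): `self_index` materializes each element's position along a dimension, a comparison between query and key positions yields an `i1` mask, and `select` picks between a score and a fill value, which is how a causal mask is typically applied before the softmax. The actual wave-kernel spelling (`tkw.self_index`, `tkw.select`, the constraints, and the attention kernel itself) is not part of the hunks shown here, so the sketch below only reproduces the intended semantics in plain NumPy, with hypothetical dimension names:

```python
import numpy as np

# Reference semantics only; M stands for the query length, K2 for the key length.
M, K2 = 4, 6
scores = np.random.rand(M, K2).astype(np.float32)

q_idx = np.arange(M)[:, None]    # what self_index over the query dim would provide
k_idx = np.arange(K2)[None, :]   # what self_index over the key dim would provide
mask = k_idx <= q_idx            # comparison op: an i1-typed register in the kernel
# SelectOp currently requires matching shapes (no broadcasting yet), so a kernel
# would broadcast mask and scores to a common shape before selecting.
masked = np.where(mask, scores, np.float32(-1e9))
```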