| 
1 | 1 | import ctypes  | 
 | 2 | +from types import EllipsisType  | 
2 | 3 | 
 
  | 
3 | 4 | import mlir.execution_engine  | 
4 | 5 | import mlir.passmanager  | 
@@ -85,12 +86,39 @@ def get_reshape_module(  | 
85 | 86 |             def reshape(a, shape):  | 
86 | 87 |                 return tensor.reshape(out_tensor_type, a, shape)  | 
87 | 88 | 
 
  | 
88 |  | -            reshape.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()  | 
89 |  | -            if DEBUG:  | 
90 |  | -                (CWD / "reshape_module.mlir").write_text(str(module))  | 
91 |  | -            pm.run(module.operation)  | 
92 |  | -            if DEBUG:  | 
93 |  | -                (CWD / "reshape_module_opt.mlir").write_text(str(module))  | 
 | 89 | +        reshape.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()  | 
 | 90 | +        if DEBUG:  | 
 | 91 | +            (CWD / "reshape_module.mlir").write_text(str(module))  | 
 | 92 | +        pm.run(module.operation)  | 
 | 93 | +        if DEBUG:  | 
 | 94 | +            (CWD / "reshape_module_opt.mlir").write_text(str(module))  | 
 | 95 | + | 
 | 96 | +    return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])  | 
 | 97 | + | 
 | 98 | + | 
@fn_cache
def get_slice_module(
    in_tensor_type: ir.RankedTensorType,
    out_tensor_type: ir.RankedTensorType,
    offsets: tuple[int, ...],
    sizes: tuple[int, ...],
    strides: tuple[int, ...],
) -> mlir.execution_engine.ExecutionEngine:
    # Build an MLIR module containing a single `getitem` function that
    # extracts a static slice (offsets/sizes/strides) from `in_tensor_type`
    # into `out_tensor_type`, run the module-level pass pipeline `pm` on it,
    # and JIT-compile it.  Cached by `fn_cache`, so each distinct
    # (types, offsets, sizes, strides) combination is compiled once.
    # NOTE(review): despite the historical `-> ir.Module` annotation this
    # returns an ExecutionEngine, mirroring get_reshape_module above.
    with ir.Location.unknown(ctx):
        module = ir.Module.create()

        with ir.InsertionPoint(module.body):

            @func.FuncOp.from_py_func(in_tensor_type)
            def getitem(a):
                # Empty dynamic operand lists ([], [], []) — all offsets,
                # sizes and strides are passed as static attributes.
                return tensor.extract_slice(out_tensor_type, a, [], [], [], offsets, sizes, strides)

        # Emit the C interface wrapper so the function is callable through
        # ExecutionEngine.invoke (same pattern as `reshape` above).
        getitem.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
        if DEBUG:
            # Dump the module before and after the pass pipeline for inspection.
            (CWD / "getitem_module.mlir").write_text(str(module))
        pm.run(module.operation)
        if DEBUG:
            (CWD / "getitem_module_opt.mlir").write_text(str(module))

    return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
96 | 124 | 
 
  | 
@@ -152,3 +180,80 @@ def reshape(x: Tensor, /, shape: tuple[int, ...]) -> Tensor:  | 
152 | 180 |     )  | 
153 | 181 | 
 
  | 
154 | 182 |     return Tensor(ret_obj, shape=out_tensor_type.shape)  | 
 | 183 | + | 
 | 184 | + | 
 | 185 | +def _add_missing_dims(key: tuple, ndim: int) -> tuple:  | 
 | 186 | +    if len(key) < ndim and Ellipsis not in key:  | 
 | 187 | +        return key + (...,)  | 
 | 188 | +    return key  | 
 | 189 | + | 
 | 190 | + | 
 | 191 | +def _expand_ellipsis(key: tuple, ndim: int) -> tuple:  | 
 | 192 | +    if Ellipsis in key:  | 
 | 193 | +        if len([e for e in key if e is Ellipsis]) > 1:  | 
 | 194 | +            raise Exception(f"Ellipsis should be used once: {key}")  | 
 | 195 | +        to_expand = ndim - len(key) + 1  | 
 | 196 | +        if to_expand <= 0:  | 
 | 197 | +            raise Exception(f"Invalid use of Ellipsis in {key}")  | 
 | 198 | +        idx = key.index(Ellipsis)  | 
 | 199 | +        return key[:idx] + tuple(slice(None) for _ in range(to_expand)) + key[idx + 1 :]  | 
 | 200 | +    return key  | 
 | 201 | + | 
 | 202 | + | 
 | 203 | +def _decompose_slices(  | 
 | 204 | +    key: tuple,  | 
 | 205 | +    shape: tuple[int, ...],  | 
 | 206 | +) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:  | 
 | 207 | +    offsets = []  | 
 | 208 | +    sizes = []  | 
 | 209 | +    strides = []  | 
 | 210 | + | 
 | 211 | +    for key_elem, size in zip(key, shape, strict=False):  | 
 | 212 | +        if isinstance(key_elem, slice):  | 
 | 213 | +            offset = key_elem.start if key_elem.start is not None else 0  | 
 | 214 | +            size = key_elem.stop - offset if key_elem.stop is not None else size - offset  | 
 | 215 | +            stride = key_elem.step if key_elem.step is not None else 1  | 
 | 216 | +        elif isinstance(key_elem, int):  | 
 | 217 | +            offset = key_elem  | 
 | 218 | +            size = key_elem + 1  | 
 | 219 | +            stride = 1  | 
 | 220 | +        offsets.append(offset)  | 
 | 221 | +        sizes.append(size)  | 
 | 222 | +        strides.append(stride)  | 
 | 223 | + | 
 | 224 | +    return tuple(offsets), tuple(sizes), tuple(strides)  | 
 | 225 | + | 
 | 226 | + | 
 | 227 | +def _get_new_shape(sizes, strides) -> tuple[int, ...]:  | 
 | 228 | +    return tuple(size // stride for size, stride in zip(sizes, strides, strict=False))  | 
 | 229 | + | 
 | 230 | + | 
def getitem(
    x: Tensor,
    key: int | slice | EllipsisType | tuple[int | slice | EllipsisType, ...],
) -> Tensor:
    # Index `x` with `key` by compiling (and caching) an MLIR extract_slice
    # kernel for the concrete offsets/sizes/strides and invoking it.
    # Only static integer/slice/Ellipsis indexing is supported.
    if not isinstance(key, tuple):
        key = (key,)
    # `None` (np.newaxis-style) entries are rejected — runtime-evaluated /
    # dimension-inserting indexing is not implemented here.
    if None in key:
        raise Exception(f"Lazy indexing isn't supported: {key}")

    # Output descriptor object; presumably filled in by the invoked kernel
    # through the double pointer below — TODO(review) confirm against
    # `_format_class` docs.
    ret_obj = x._format_class()

    # Normalize the key: pad with a trailing Ellipsis if too short, then
    # expand the Ellipsis into explicit full slices.
    key = _add_missing_dims(key, x.ndim)
    key = _expand_ellipsis(key, x.ndim)
    offsets, sizes, strides = _decompose_slices(key, x.shape)

    new_shape = _get_new_shape(sizes, strides)
    out_tensor_type = x._obj.get_tensor_definition(new_shape)

    # Cached compile of the slice kernel for these static parameters.
    slice_module = get_slice_module(
        x._obj.get_tensor_definition(x.shape),
        out_tensor_type,
        offsets,
        sizes,
        strides,
    )

    # llvm.emit_c_interface convention: result is written through a
    # pointer-to-pointer first argument, inputs follow.
    slice_module.invoke("getitem", ctypes.pointer(ctypes.pointer(ret_obj)), *x._obj.to_module_arg())

    return Tensor(ret_obj, shape=out_tensor_type.shape)
0 commit comments