Skip to content

Commit

Permalink
Merge branch 'ershi/fix-more-tile-docs' into 'main'
Browse files Browse the repository at this point in the history
Fix various issues with Tile docs

See merge request omniverse/warp!940
  • Loading branch information
shi-eric committed Dec 20, 2024
2 parents 048587d + 429c281 commit c5b8568
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 74 deletions.
47 changes: 24 additions & 23 deletions docs/modules/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,7 @@ Tile Primitives
:returns: A tile with ``shape=(m,n)`` and dtype the same as the source array


.. py:function:: tile_store(a: Array[Any], i: int32, t: Any) -> None
.. py:function:: tile_store(a: Array[Any], i: int32, t: Tile) -> None
Stores a 1D tile to a global memory array.

Expand All @@ -887,7 +887,7 @@ Tile Primitives
:param t: The source tile to store data from, must have the same dtype as the destination array


.. py:function:: tile_store(a: Array[Any], i: int32, j: int32, t: Any) -> None
.. py:function:: tile_store(a: Array[Any], i: int32, j: int32, t: Tile) -> None
:noindex:
:nocontentsentry:

Expand All @@ -901,7 +901,7 @@ Tile Primitives
:param t: The source tile to store data from, must have the same dtype as the destination array


.. py:function:: tile_atomic_add(a: Array[Any], x: int32, y: int32, t: Any) -> Tile
.. py:function:: tile_atomic_add(a: Array[Any], x: int32, y: int32, t: Tile) -> Tile
Atomically add a tile to the array `a`, each element will be updated atomically.

Expand Down Expand Up @@ -967,7 +967,7 @@ Tile Primitives
.. py:function:: untile(a: Any) -> Scalar
.. py:function:: untile(a: Tile) -> Scalar
Convert a Tile back to per-thread values.

Expand All @@ -991,7 +991,7 @@ Tile Primitives
t = wp.tile(i)*2
# convert back to per-thread values
s = wp.untile()
s = wp.untile(t)
print(s)
Expand Down Expand Up @@ -1038,7 +1038,7 @@ Tile Primitives
Broadcast a tile.

This method will attempt to broadcast the input tile ``a`` to the destination shape (m, n), broadcasting follows NumPy broadcast rules.
This function will attempt to broadcast the input tile ``a`` to the destination shape (m, n), broadcasting follows NumPy broadcast rules.

:param a: Tile to broadcast
:returns: Tile with broadcast ``shape=(m, n)``
Expand All @@ -1061,9 +1061,9 @@ Tile Primitives
t = wp.tile_ones(dtype=float, m=16, n=16)
s = wp.tile_sum(t)
print(t)
print(s)
wp.launch(compute, dim=[64], inputs=[])
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
Prints:

Expand All @@ -1088,18 +1088,19 @@ Tile Primitives
@wp.kernel
def compute():
t = wp.tile_arange(start=--10, stop=10, dtype=float)
t = wp.tile_arange(64, 128)
s = wp.tile_min(t)
print(t)
print(s)
wp.launch(compute, dim=[64], inputs=[])
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
Prints:

.. code-block:: text
tile(m=1, n=1, storage=register) = [[-10]]
tile(m=1, n=1, storage=register) = [[64 ]]
Expand All @@ -1118,23 +1119,23 @@ Tile Primitives
@wp.kernel
def compute():
t = wp.tile_arange(start=--10, stop=10, dtype=float)
s = wp.tile_min(t)
t = wp.tile_arange(64, 128)
s = wp.tile_max(t)
print(t)
print(s)
wp.launch(compute, dim=[64], inputs=[])
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
Prints:

.. code-block:: text
tile(m=1, n=1, storage=register) = [[10]]
tile(m=1, n=1, storage=register) = [[127 ]]
.. py:function:: tile_reduce(op: Callable, a: Any) -> Tile
.. py:function:: tile_reduce(op: Callable, a: Tile) -> Tile
Apply a custom reduction operator across the tile.

Expand All @@ -1156,7 +1157,7 @@ Tile Primitives
print(s)
wp.launch(factorial, dim=[16], inputs=[], block_dim=16)
wp.launch_tiled(factorial, dim=[1], inputs=[], block_dim=16)
Prints:

Expand All @@ -1166,7 +1167,7 @@ Tile Primitives
.. py:function:: tile_map(op: Callable, a: Any) -> Tile
.. py:function:: tile_map(op: Callable, a: Tile) -> Tile
Apply a unary function onto the tile.

Expand All @@ -1188,7 +1189,7 @@ Tile Primitives
print(s)
wp.launch(compute, dim=[16], inputs=[])
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=16)
Prints:

Expand All @@ -1198,7 +1199,7 @@ Tile Primitives
.. py:function:: tile_map(op: Callable, a: Any, b: Any) -> Tile
.. py:function:: tile_map(op: Callable, a: Tile, b: Tile) -> Tile
:noindex:
:nocontentsentry:

Expand Down Expand Up @@ -1226,7 +1227,7 @@ Tile Primitives
print(s)
wp.launch(compute, dim=[16], inputs=[])
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=16)
Prints:

Expand Down
48 changes: 25 additions & 23 deletions warp/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -1852,6 +1852,7 @@ def tile_arange_value_func(arg_types: Mapping[str, type], arg_values: Mapping[st
step = args[2]

if start is None or stop is None or step is None:
print(args)
raise RuntimeError("wp.tile_arange() arguments must be compile time constants")

if "dtype" in arg_values:
Expand Down Expand Up @@ -2083,7 +2084,7 @@ def tile_store_1d_value_func(arg_types, arg_values):

add_builtin(
"tile_store",
input_types={"a": array(dtype=Any), "i": int, "t": Any},
input_types={"a": array(dtype=Any), "i": int, "t": Tile(dtype=Any, M=Any, N=Any)},
value_func=tile_store_1d_value_func,
variadic=False,
skip_replay=True,
Expand Down Expand Up @@ -2132,7 +2133,7 @@ def tile_store_2d_value_func(arg_types, arg_values):

add_builtin(
"tile_store",
input_types={"a": array(dtype=Any), "i": int, "j": int, "t": Any},
input_types={"a": array(dtype=Any), "i": int, "j": int, "t": Tile(dtype=Any, M=Any, N=Any)},
value_func=tile_store_2d_value_func,
variadic=False,
skip_replay=True,
Expand Down Expand Up @@ -2177,7 +2178,7 @@ def tile_atomic_add_value_func(arg_types, arg_values):

add_builtin(
"tile_atomic_add",
input_types={"a": array(dtype=Any), "x": int, "y": int, "t": Any},
input_types={"a": array(dtype=Any), "x": int, "y": int, "t": Tile(dtype=Any, M=Any, N=Any)},
value_func=tile_atomic_add_value_func,
variadic=True,
skip_replay=True,
Expand Down Expand Up @@ -2365,7 +2366,7 @@ def untile_value_func(arg_types, arg_values):

add_builtin(
"untile",
input_types={"a": Any},
input_types={"a": Tile(dtype=Any, M=Any, N=Any)},
value_func=untile_value_func,
variadic=True,
doc="""Convert a Tile back to per-thread values.
Expand All @@ -2390,7 +2391,7 @@ def compute():
t = wp.tile(i)*2
# convert back to per-thread values
s = wp.untile()
s = wp.untile(t)
print(s)
Expand Down Expand Up @@ -2562,7 +2563,7 @@ def tile_broadcast_dispatch_func(arg_types: Mapping[str, type], return_type: Any
variadic=True,
doc="""Broadcast a tile.
This method will attempt to broadcast the input tile ``a`` to the destination shape (m, n), broadcasting follows NumPy broadcast rules.
This function will attempt to broadcast the input tile ``a`` to the destination shape (m, n), broadcasting follows NumPy broadcast rules.
:param a: Tile to broadcast
:returns: Tile with broadcast ``shape=(m, n)``""",
Expand Down Expand Up @@ -2654,9 +2655,9 @@ def compute():
t = wp.tile_ones(dtype=float, m=16, n=16)
s = wp.tile_sum(t)
print(t)
print(s)
wp.launch(compute, dim=[64], inputs=[])
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
Prints:
Expand Down Expand Up @@ -2703,18 +2704,19 @@ def tile_min_value_func(arg_types, arg_values):
@wp.kernel
def compute():
t = wp.tile_arange(start=--10, stop=10, dtype=float)
t = wp.tile_arange(64, 128)
s = wp.tile_min(t)
print(t)
print(s)
wp.launch(compute, dim=[64], inputs=[])
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
Prints:
.. code-block:: text
tile(m=1, n=1, storage=register) = [[-10]]
tile(m=1, n=1, storage=register) = [[64 ]]
""",
group="Tile Primitives",
Expand Down Expand Up @@ -2755,18 +2757,18 @@ def tile_max_value_func(arg_types, arg_values):
@wp.kernel
def compute():
t = wp.tile_arange(start=--10, stop=10, dtype=float)
s = wp.tile_min(t)
t = wp.tile_arange(64, 128)
s = wp.tile_max(t)
print(t)
print(s)
wp.launch(compute, dim=[64], inputs=[])
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
Prints:
.. code-block:: text
tile(m=1, n=1, storage=register) = [[10]]
tile(m=1, n=1, storage=register) = [[127 ]]
""",
group="Tile Primitives",
Expand Down Expand Up @@ -2796,7 +2798,7 @@ def tile_reduce_dispatch_func(input_types: Mapping[str, type], return_type: Any,

add_builtin(
"tile_reduce",
input_types={"op": Callable, "a": Any},
input_types={"op": Callable, "a": Tile(dtype=Any, M=Any, N=Any)},
value_func=tile_reduce_value_func,
native_func="tile_reduce",
doc="""Apply a custom reduction operator across the tile.
Expand All @@ -2819,7 +2821,7 @@ def factorial():
print(s)
wp.launch(factorial, dim=[16], inputs=[], block_dim=16)
wp.launch_tiled(factorial, dim=[1], inputs=[], block_dim=16)
Prints:
Expand Down Expand Up @@ -2856,7 +2858,7 @@ def tile_unary_map_value_func(arg_types, arg_values):

add_builtin(
"tile_map",
input_types={"op": Callable, "a": Any},
input_types={"op": Callable, "a": Tile(dtype=Any, M=Any, N=Any)},
value_func=tile_unary_map_value_func,
# dispatch_func=tile_map_dispatch_func,
# variadic=True,
Expand All @@ -2881,7 +2883,7 @@ def compute():
print(s)
wp.launch(compute, dim=[16], inputs=[])
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=16)
Prints:
Expand Down Expand Up @@ -2923,7 +2925,7 @@ def tile_binary_map_value_func(arg_types, arg_values):

add_builtin(
"tile_map",
input_types={"op": Callable, "a": Any, "b": Any},
input_types={"op": Callable, "a": Tile(dtype=Any, M=Any, N=Any), "b": Tile(dtype=Any, M=Any, N=Any)},
value_func=tile_binary_map_value_func,
# dispatch_func=tile_map_dispatch_func,
# variadic=True,
Expand Down Expand Up @@ -2952,7 +2954,7 @@ def compute():
print(s)
wp.launch(compute, dim=[16], inputs=[])
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=16)
Prints:
Expand Down
7 changes: 2 additions & 5 deletions warp/native/tile.h
Original file line number Diff line number Diff line change
Expand Up @@ -1125,8 +1125,6 @@ inline CUDA_CALLABLE auto untile(Tile& tile)
}
}



template <typename Tile, typename Value>
inline CUDA_CALLABLE void adj_untile(Tile& tile, Tile& adj_tile, Value& adj_ret)
{
Expand Down Expand Up @@ -1156,15 +1154,15 @@ inline CUDA_CALLABLE auto tile_zeros()
return T(0);
}

// zero initialized tile
// one-initialized tile
template <typename T, int M, int N>
inline CUDA_CALLABLE auto tile_ones()
{
// tile variable assignment operator will handle initialization (since lhs could be shared/register tile)
return T(1);
}

// zero initialized tile
// tile with evenly spaced values
template <typename T, int M, int N>
inline CUDA_CALLABLE auto tile_arange(T start, T stop, T step)
{
Expand Down Expand Up @@ -1220,7 +1218,6 @@ inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, Tile& src)
src.copy_to_global(dest, x, y);
}

// entry point for store
template <typename T, typename Tile>
inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, Tile& src)
{
Expand Down
Loading

0 comments on commit c5b8568

Please sign in to comment.