Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support outer reduction scheduler with SOL autotuning #3618

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
27 changes: 27 additions & 0 deletions csrc/python_frontend/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <scheduler/utils.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <transform_replay.h>
#include <utils.h>
#include <iostream>
#include <optional>
#include <tuple>
Expand Down Expand Up @@ -761,6 +762,32 @@ void initNvFuserPythonBindings(PyObject* module) {

nvfuser.def("clone", clone);

nvfuser.def(
"get_registers_per_thread",
getRegPerThreadGivenThreadsPerSM,
py::arg("threads_per_sm"),
R"(
Estimate the number of registers per thread using cuda occupancy API.

Parameters
----------
threads_per_sm : int
The number of threads per SM.
)");

nvfuser.def(
"get_threads_per_sm",
getThreadsPerSMGivenRegPerThread,
py::arg("reg_per_thread"),
R"(
Get number of threads per sm using cuda occupancy API.

Parameters
----------
reg_per_thread : int
The number of registers per thread.
)");

//! DataTypes supported by nvFuser in the FusionDefinition
py::enum_<PrimDataType>(nvfuser, "DataType")
.value("Double", DataType::Double)
Expand Down
2 changes: 1 addition & 1 deletion doc/dev/python_scheduling/autotune_inner_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def inner_fn(num_inputs):
assert False

# A decorator to create a reduction fusion given some input arguments.
def create_fusion_func(self, inputs):
def create_fusion_func(self):
def sum_fusion(fd: FusionDefinition) -> None:
T0 = fd.define_tensor(
shape=[-1, -1],
Expand Down
Loading
Loading