[RFC] Multithread stream #1570

Draft · wants to merge 3 commits into base: main
12 changes: 6 additions & 6 deletions mlx/scheduler.cpp
@@ -21,16 +21,16 @@ void set_default_stream(Stream s) {
   return scheduler::scheduler().set_default_stream(s);
 }
 
-Stream new_stream(Device d) {
+Stream new_stream(Device d, int threads /* = 1 */) {
   if (!metal::is_available() && d == Device::gpu) {
     throw std::invalid_argument(
         "[new_stream] Cannot make gpu stream without gpu backend.");
   }
-  return scheduler::scheduler().new_stream(d);
-}
-
-Stream new_stream() {
-  return scheduler::scheduler().new_stream(default_device());
+  if (d == Device::gpu && threads > 1) {
+    throw std::invalid_argument(
+        "[new_stream] Cannot make multi-threaded gpu stream.");
+  }
+  return scheduler::scheduler().new_stream(d, threads);
 }
 
 void synchronize(Stream s) {
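A small hypothetical usage sketch of the signature above, for context only (not part of this diff; `ones`, `exp`, `eval`, and `synchronize` are existing MLX calls assumed here, and `threads` is the argument this change introduces):

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Ask for a CPU stream backed by 4 worker threads. GPU streams still
  // reject threads > 1, per the check added above.
  auto s = new_stream(Device(Device::cpu), /* threads = */ 4);
  auto a = ones({512, 512}, s);
  auto b = exp(a, s); // work enqueued on the multi-threaded stream
  eval({b});
  synchronize(s);
  return 0;
}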
21 changes: 13 additions & 8 deletions mlx/scheduler.h
@@ -21,11 +21,14 @@ struct StreamThread {
   std::condition_variable cond;
   bool stop;
   Stream stream;
-  std::thread thread;
+  std::vector<std::thread> threads;
 
-  StreamThread(Stream stream)
-      : stop(false), stream(stream), thread(&StreamThread::thread_fn, this) {
+  StreamThread(Stream stream, int num_threads = 1)
+      : stop(false), stream(stream) {
     metal::new_stream(stream);
+    for (int i = 0; i < num_threads; ++i) {
+      threads.emplace_back(&StreamThread::thread_fn, this);
+    }
   }
 
   ~StreamThread() {
@@ -34,8 +37,10 @@
       std::lock_guard<std::mutex> lk(mtx);
       stop = true;
     }
-    cond.notify_one();
-    thread.join();
+    cond.notify_all();
+    for (auto& t : threads) {
+      t.join();
+    }
   }
 
   void thread_fn() {
@@ -84,9 +89,9 @@ class Scheduler {
   Scheduler& operator=(const Scheduler&) = delete;
   Scheduler& operator=(Scheduler&&) = delete;
 
-  Stream new_stream(const Device& d) {
-    auto stream = Stream(streams_.size(), d);
-    streams_.push_back(new StreamThread{stream, threads});
+  Stream new_stream(const Device& d, int threads = 1) {
+    auto stream = Stream(streams_.size(), d, threads);
+    streams_.push_back(new StreamThread{stream, threads});
     return stream;
  }
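The thread_fn and enqueue bodies are collapsed in this view. As a rough, standalone sketch of the pattern the constructor and destructor changes imply (the WorkerPool name, the q member, and enqueue are illustrative assumptions, not MLX code):

#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

struct WorkerPool {
  std::mutex mtx;
  std::condition_variable cond;
  bool stop{false};
  std::queue<std::function<void()>> q;
  std::vector<std::thread> threads;

  explicit WorkerPool(int num_threads) {
    for (int i = 0; i < num_threads; ++i) {
      threads.emplace_back(&WorkerPool::thread_fn, this);
    }
  }

  ~WorkerPool() {
    {
      std::lock_guard<std::mutex> lk(mtx);
      stop = true;
    }
    cond.notify_all(); // wake every worker, not just one
    for (auto& t : threads) {
      t.join();
    }
  }

  void enqueue(std::function<void()> task) {
    {
      std::lock_guard<std::mutex> lk(mtx);
      q.push(std::move(task));
    }
    cond.notify_one();
  }

  void thread_fn() {
    while (true) {
      std::unique_lock<std::mutex> lk(mtx);
      cond.wait(lk, [this] { return !q.empty() || stop; });
      if (q.empty() && stop) {
        return;
      }
      auto task = std::move(q.front());
      q.pop();
      lk.unlock();
      task(); // tasks from one stream may now finish out of order across workers
    }
  }
};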
6 changes: 4 additions & 2 deletions mlx/stream.h
@@ -9,7 +9,9 @@ namespace mlx::core {
 struct Stream {
   int index;
   Device device;
-  explicit Stream(int index, Device device) : index(index), device(device) {}
+  int threads;
+  explicit Stream(int index, Device device, int threads = 1)
+      : index(index), device(device), threads(threads) {}
 };
@@ -19,7 +21,7 @@ Stream default_stream(Device d);
 void set_default_stream(Stream s);
 
 /** Make a new stream on the given device. */
-Stream new_stream(Device d);
+Stream new_stream(Device d, int threads = 1);
 
 inline bool operator==(const Stream& lhs, const Stream& rhs) {
   return lhs.index == rhs.index;
29 changes: 20 additions & 9 deletions mlx/transforms.cpp
@@ -185,15 +185,24 @@ array eval_impl(std::vector<array> outputs, bool async) {
 
     auto stream = arr.primitive().stream();
 
-    // Lookup corresponding event and increment counter
-    auto e = events.find(stream.index);
-    if (e == events.end()) {
-      e = events.emplace(stream.index, Event{stream}).first;
+    Event e;
+    if (stream.threads > 1) {
+      // Use unique events for multi-threaded streams
+      e = Event(stream);
+      e.set_value(1);
+    } else {
+      // Share events for single-threaded streams
+      auto e_it = events.find(stream.index);
+      if (e_it == events.end()) {
+        e_it = events.emplace(stream.index, Event{stream}).first;
+      }
+      e_it->second.set_value(e_it->second.value() + 1);
+      e = e_it->second;
     }
-    e->second.set_value(e->second.value() + 1);
-    arr.attach_event(e->second);
+    // Increment event counter and attach to the array and siblings
+    arr.attach_event(e);
     for (auto& s : arr.siblings()) {
-      s.attach_event(e->second);
+      s.attach_event(e);
     }
 
     // Set the status of the array and siblings.
@@ -203,7 +212,8 @@
     }
 
     std::vector<std::shared_future<void>> arr_deps;
-    bool signal = needs_signal.find(arr.id()) != needs_signal.end();
+    bool signal =
+        stream.threads > 1 || needs_signal.find(arr.id()) != needs_signal.end();
 
     if (arr.primitive().device() == Device::gpu) {
       if (!metal::is_available()) {
@@ -214,7 +224,8 @@
       auto task = [arr = std::move(arr), stream, signal]() mutable {
         for (auto& input : arr.inputs()) {
          if (input.event().valid() &&
-              input.event().stream() != arr.primitive().stream()) {
+              (stream.threads > 1 ||
+               input.event().stream() != arr.primitive().stream())) {
             input.event().wait();
           }
         }
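As a condensed restatement of the branch above (illustrative only; it relies just on the Event members already visible in this diff — construction from a Stream, value(), set_value() — and assumes the mlx/event.h and mlx/stream.h header paths):

#include <unordered_map>

#include "mlx/event.h"
#include "mlx/stream.h"

using namespace mlx::core;

// Pick the event an array should signal, given its stream.
Event event_for(const Stream& stream, std::unordered_map<int, Event>& events) {
  if (stream.threads > 1) {
    // Workers can finish out of order, so each array gets its own event,
    // signaled exactly once by the task that produced it.
    Event e{stream};
    e.set_value(1);
    return e;
  }
  // A single-threaded stream runs tasks in FIFO order, so one shared event
  // per stream is enough; bump its target value for each scheduled array.
  auto it = events.find(stream.index);
  if (it == events.end()) {
    it = events.emplace(stream.index, Event{stream}).first;
  }
  it->second.set_value(it->second.value() + 1);
  return it->second;
}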
2 changes: 1 addition & 1 deletion mlx/utils.cpp
@@ -142,7 +142,7 @@ std::ostream& operator<<(std::ostream& os, const Device& d) {
 std::ostream& operator<<(std::ostream& os, const Stream& s) {
   os << "Stream(";
   os << s.device;
-  os << ", " << s.index << ")";
+  os << ", index=" << s.index << ", threads=" << s.threads << ")";
   return os;
 }
8 changes: 0 additions & 8 deletions python/src/array.cpp
@@ -184,10 +184,6 @@ void init_array(nb::module_& m) {
       R"pbdoc(
       A helper object to apply updates at specific indices.
       )pbdoc")
-      .def(
-          nb::init<const array&>(),
-          "x"_a,
-          nb::sig("def __init__(self, x: array)"))
       .def("__getitem__", &ArrayAt::set_indices, "indices"_a.none())
       .def("add", &ArrayAt::add, "value"_a)
       .def("subtract", &ArrayAt::subtract, "value"_a)
@@ -202,10 +198,6 @@
       R"pbdoc(
       A helper object to iterate over the 1st dimension of an array.
       )pbdoc")
-      .def(
-          nb::init<const array&>(),
-          "x"_a,
-          nb::sig("def __init__(self, x: array)"))
       .def("__next__", &ArrayPythonIterator::next)
       .def("__iter__", [](const ArrayPythonIterator& it) { return it; });
5 changes: 4 additions & 1 deletion python/src/stream.cpp
@@ -47,6 +47,8 @@ void init_stream(nb::module_& m) {
       "Stream",
       R"pbdoc(
       A stream for running operations on a given device.
+
+      Use :func:`new_stream` to create new streams.
       )pbdoc")
       .def(nb::init<int, Device>(), "index"_a, "device"_a)
       .def_ro("device", &Stream::device)
@@ -79,12 +81,13 @@
       streams device. It will not change the default device.
 
       Args:
-          stream (stream): Stream to make the default.
+          stream (Stream): Stream to make the default.
       )pbdoc");
   m.def(
       "new_stream",
       &new_stream,
       "device"_a,
+      "threads"_a = 1,
       R"pbdoc(Make a new stream on the given device.)pbdoc");
 
   nb::class_<PyStreamContext>(m, "StreamContext", R"pbdoc(
12 changes: 12 additions & 0 deletions python/tests/test_eval.py
@@ -137,6 +137,18 @@ def test_async_eval_with_multiple_streams(self):
         mx.async_eval(x)
         mx.eval(a + b)
 
+    def test_multithreaded_stream(self):
+        arrays = [mx.random.uniform(shape=(4, 4)) for _ in range(8)]
+        mx.eval(arrays)
+        s = mx.new_stream(mx.cpu, threads=2)
+        with mx.stream(s):
+            outputs = [mx.exp(-mx.abs(x)) for x in arrays]
+            out_multi = sum(outputs)
+
+        outputs = [mx.exp(-mx.abs(x)) for x in arrays]
+        out = sum(outputs)
+        self.assertTrue(mx.allclose(out, out_multi))
+
 
 if __name__ == "__main__":
     unittest.main()