Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
awni committed Nov 9, 2024
1 parent 4938412 commit 4eeae4b
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 27 deletions.
37 changes: 20 additions & 17 deletions mlx/transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,24 +185,24 @@ array eval_impl(std::vector<array> outputs, bool async) {

auto stream = arr.primitive().stream();

// Lookup corresponding event and increment counter
Event e;
if (stream.threads > 1) {
auto e = Event(stream);
e.set_value(e.value() + 1);
arr.attach_event(e);
for (auto& s : arr.siblings()) {
s.attach_event(e);
}
// Use unique events for multi-threaded streams
e = Event(stream);
e.set_value(1);
} else {
auto e = events.find(stream.index);
if (e == events.end()) {
e = events.emplace(stream.index, Event{stream}).first;
}
e->second.set_value(e->second.value() + 1);
arr.attach_event(e->second);
for (auto& s : arr.siblings()) {
s.attach_event(e->second);
// Share events for single-threaded streams
auto e_it = events.find(stream.index);
if (e_it == events.end()) {
e_it = events.emplace(stream.index, Event{stream}).first;
}
e_it->second.set_value(e_it->second.value() + 1);
e = e_it->second;
}
// Increment event counter and attach to the array and siblings
arr.attach_event(e);
for (auto& s : arr.siblings()) {
s.attach_event(e);
}

// Set the status of the array and siblings.
Expand All @@ -212,7 +212,8 @@ array eval_impl(std::vector<array> outputs, bool async) {
}

std::vector<std::shared_future<void>> arr_deps;
bool signal = needs_signal.find(arr.id()) != needs_signal.end();
bool signal =
stream.threads > 1 || needs_signal.find(arr.id()) != needs_signal.end();

if (arr.primitive().device() == Device::gpu) {
if (!metal::is_available()) {
Expand All @@ -222,7 +223,9 @@ array eval_impl(std::vector<array> outputs, bool async) {
} else {
auto task = [arr = std::move(arr), stream, signal]() mutable {
for (auto& input : arr.inputs()) {
if (input.event().valid()) {
if (input.event().valid() &&
(stream.threads > 1 ||
input.event().stream() != arr.primitive().stream())) {
input.event().wait();
}
}
Expand Down
2 changes: 1 addition & 1 deletion mlx/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ std::ostream& operator<<(std::ostream& os, const Device& d) {
std::ostream& operator<<(std::ostream& os, const Stream& s) {
os << "Stream(";
os << s.device;
os << ", " << s.index << ")";
os << ", index=" << s.index << ", threads=" << s.threads << ")";
return os;
}

Expand Down
8 changes: 0 additions & 8 deletions python/src/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,6 @@ void init_array(nb::module_& m) {
R"pbdoc(
A helper object to apply updates at specific indices.
)pbdoc")
.def(
nb::init<const array&>(),
"x"_a,
nb::sig("def __init__(self, x: array)"))
.def("__getitem__", &ArrayAt::set_indices, "indices"_a.none())
.def("add", &ArrayAt::add, "value"_a)
.def("subtract", &ArrayAt::subtract, "value"_a)
Expand All @@ -202,10 +198,6 @@ void init_array(nb::module_& m) {
R"pbdoc(
A helper object to iterate over the 1st dimension of an array.
)pbdoc")
.def(
nb::init<const array&>(),
"x"_a,
nb::sig("def __init__(self, x: array)"))
.def("__next__", &ArrayPythonIterator::next)
.def("__iter__", [](const ArrayPythonIterator& it) { return it; });

Expand Down
4 changes: 3 additions & 1 deletion python/src/stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ void init_stream(nb::module_& m) {
"Stream",
R"pbdoc(
A stream for running operations on a given device.
Use :func:`new_stream` to create new streams.
)pbdoc")
.def(nb::init<int, Device>(), "index"_a, "device"_a)
.def_ro("device", &Stream::device)
Expand Down Expand Up @@ -79,7 +81,7 @@ void init_stream(nb::module_& m) {
streams device. It will not change the default device.
Args:
stream (stream): Stream to make the default.
stream (Stream): Stream to make the default.
)pbdoc");
m.def(
"new_stream",
Expand Down
12 changes: 12 additions & 0 deletions python/tests/test_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,18 @@ def test_async_eval_with_multiple_streams(self):
mx.async_eval(x)
mx.eval(a + b)

def test_multithreaded_stream(self):
arrays = [mx.random.uniform(shape=(4, 4)) for _ in range(8)]
mx.eval(arrays)
s = mx.new_stream(mx.cpu, threads=2)
with mx.stream(s):
outputs = [mx.exp(-mx.abs(x)) for x in arrays]
out_multi = sum(outputs)

outputs = [mx.exp(-mx.abs(x)) for x in arrays]
out = sum(outputs)
self.assertTrue(mx.allclose(out, out_multi))


if __name__ == "__main__":
unittest.main()

0 comments on commit 4eeae4b

Please sign in to comment.