Skip to content

Commit

Permalink
attempt
Browse files: browse the repository at this point in the history
  • Loading branch information
awni committed Nov 7, 2024
1 parent 9a3842a commit 6c0ad3f
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 26 deletions.
8 changes: 2 additions & 6 deletions mlx/scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,12 @@ void set_default_stream(Stream s) {
return scheduler::scheduler().set_default_stream(s);
}

Stream new_stream(Device d) {
Stream new_stream(Device d, int threads /* = 1 */) {
if (!metal::is_available() && d == Device::gpu) {
throw std::invalid_argument(
"[new_stream] Cannot make gpu stream without gpu backend.");
}
return scheduler::scheduler().new_stream(d);
}

Stream new_stream() {
return scheduler::scheduler().new_stream(default_device());
return scheduler::scheduler().new_stream(d, threads);
}

void synchronize(Stream s) {
Expand Down
21 changes: 13 additions & 8 deletions mlx/scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@ struct StreamThread {
std::condition_variable cond;
bool stop;
Stream stream;
std::thread thread;
std::vector<std::thread> threads;

StreamThread(Stream stream)
: stop(false), stream(stream), thread(&StreamThread::thread_fn, this) {
StreamThread(Stream stream, int num_threads = 1)
: stop(false), stream(stream) {
metal::new_stream(stream);
for (int i = 0; i < num_threads; ++i) {
threads.emplace_back(&StreamThread::thread_fn, this);
}
}

~StreamThread() {
Expand All @@ -34,8 +37,10 @@ struct StreamThread {
std::lock_guard<std::mutex> lk(mtx);
stop = true;
}
cond.notify_one();
thread.join();
cond.notify_all();
for (auto& t : threads) {
t.join();
}
}

void thread_fn() {
Expand Down Expand Up @@ -84,9 +89,9 @@ class Scheduler {
Scheduler& operator=(const Scheduler&) = delete;
Scheduler& operator=(Scheduler&&) = delete;

Stream new_stream(const Device& d) {
auto stream = Stream(streams_.size(), d);
streams_.push_back(new StreamThread{stream});
Stream new_stream(const Device& d, int threads = 1) {
auto stream = Stream(streams_.size(), d, threads);
streams_.push_back(new StreamThread{stream, threads});
return stream;
}

Expand Down
6 changes: 4 additions & 2 deletions mlx/stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ namespace mlx::core {
struct Stream {
int index;
Device device;
explicit Stream(int index, Device device) : index(index), device(device) {}
int threads;
explicit Stream(int index, Device device, int threads = 1)
: index(index), device(device), threads(threads) {}
};

/** Get the default stream for the given device. */
Expand All @@ -19,7 +21,7 @@ Stream default_stream(Device d);
void set_default_stream(Stream s);

/** Make a new stream on the given device. */
Stream new_stream(Device d);
Stream new_stream(Device d, int threads = 1);

inline bool operator==(const Stream& lhs, const Stream& rhs) {
return lhs.index == rhs.index;
Expand Down
28 changes: 18 additions & 10 deletions mlx/transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,23 @@ array eval_impl(std::vector<array> outputs, bool async) {
auto stream = arr.primitive().stream();

// Lookup corresponding event and increment counter
auto e = events.find(stream.index);
if (e == events.end()) {
e = events.emplace(stream.index, Event{stream}).first;
}
e->second.set_value(e->second.value() + 1);
arr.attach_event(e->second);
for (auto& s : arr.siblings()) {
s.attach_event(e->second);
if (stream.threads > 1) {
auto e = Event(stream);
e.set_value(e.value() + 1);
arr.attach_event(e);
for (auto& s : arr.siblings()) {
s.attach_event(e);
}
} else {
auto e = events.find(stream.index);
if (e == events.end()) {
e = events.emplace(stream.index, Event{stream}).first;
}
e->second.set_value(e->second.value() + 1);
arr.attach_event(e->second);
for (auto& s : arr.siblings()) {
s.attach_event(e->second);
}
}

// Set the status of the array and siblings.
Expand All @@ -193,8 +202,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
} else {
auto task = [arr = std::move(arr), stream, signal]() mutable {
for (auto& input : arr.inputs()) {
if (input.event().valid() &&
input.event().stream() != arr.primitive().stream()) {
if (input.event().valid()) {
input.event().wait();
}
}
Expand Down
1 change: 1 addition & 0 deletions python/src/stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ void init_stream(nb::module_& m) {
"new_stream",
&new_stream,
"device"_a,
"threads"_a = 1,
R"pbdoc(Make a new stream on the given device.)pbdoc");

nb::class_<PyStreamContext>(m, "StreamContext", R"pbdoc(
Expand Down

0 comments on commit 6c0ad3f

Please sign in to comment.