Skip to content

Commit 268191f

Browse files
Ilia Cherniavskiifacebook-github-bot
authored andcommitted
Use global thread pool in async_scheduling
Summary: Simplify async_scheduling to use global thread pool instead of per network polling threads Reviewed By: romain-intel Differential Revision: D6814274 fbshipit-source-id: f91ac3e99d9b8cf15578a751ed7929be84840408
1 parent d083c62 commit 268191f

File tree

2 files changed

+14
-46
lines changed

2 files changed

+14
-46
lines changed

caffe2/core/net_async_scheduling.cc

Lines changed: 13 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -21,40 +21,19 @@ CAFFE2_DEFINE_bool(
2121
false,
2222
"Always schedule child chains from parent chain");
2323

24-
CAFFE2_DEFINE_int(
25-
caffe2_net_async_polling_threads_num,
26-
1,
27-
"Number of polling threads in async_scheduling executor");
28-
2924
namespace caffe2 {
3025

3126
AsyncSchedulingNet::AsyncSchedulingNet(
3227
const std::shared_ptr<const NetDef>& net_def,
3328
Workspace* ws)
3429
: AsyncNetBase(net_def, ws), running_(false) {
35-
pending_tasks_.reserve(FLAGS_caffe2_net_async_polling_threads_num);
36-
for (auto thread_num = 0;
37-
thread_num < FLAGS_caffe2_net_async_polling_threads_num;
38-
++thread_num) {
39-
pending_tasks_.push_back(caffe2::make_unique<SimpleQueue<int>>());
40-
}
41-
42-
polling_threads_.reserve(FLAGS_caffe2_net_async_polling_threads_num);
43-
for (auto thread_num = 0;
44-
thread_num < FLAGS_caffe2_net_async_polling_threads_num;
45-
++thread_num) {
46-
polling_threads_.push_back(
47-
std::thread(&AsyncSchedulingNet::pollAndSchedule, this, thread_num));
48-
}
49-
5030
reset();
5131
}
5232

5333
void AsyncSchedulingNet::reset() {
5434
processed_tasks_num_ = 0;
5535
cleanup_ = false;
5636
success_ = true;
57-
next_polling_thread_counter_ = 0;
5837

5938
for (auto task_id = 0; task_id < tasksNum(); ++task_id) {
6039
auto& task_ops = chains_[task_id];
@@ -90,9 +69,10 @@ void AsyncSchedulingNet::schedule(int task_id) {
9069
canSchedule(child_id)) {
9170
schedule(child_id);
9271
} else {
93-
auto polling_thread_id = next_polling_thread_counter_++;
94-
polling_thread_id %= FLAGS_caffe2_net_async_polling_threads_num;
95-
pending_tasks_[polling_thread_id]->Push(child_id);
72+
const auto& device_option = event(child_id).GetDeviceOption();
73+
pool(device_option)
74+
->run(std::bind(
75+
&AsyncSchedulingNet::pollAndSchedule, this, child_id));
9676
}
9777
}
9878
}
@@ -133,15 +113,14 @@ void AsyncSchedulingNet::schedule(int task_id) {
133113
});
134114
}
135115

136-
void AsyncSchedulingNet::pollAndSchedule(int thread_id) {
137-
int task_id;
138-
while (pending_tasks_[thread_id]->Pop(&task_id)) {
139-
if (canSchedule(task_id) || cleanup_) {
140-
// force schedule the rest of the tasks if cleanup is started
141-
schedule(task_id);
142-
} else {
143-
pending_tasks_[thread_id]->Push(task_id);
144-
}
116+
void AsyncSchedulingNet::pollAndSchedule(int task_id) {
117+
if (canSchedule(task_id) || cleanup_) {
118+
// force schedule the rest of the tasks if cleanup is started
119+
schedule(task_id);
120+
} else {
121+
const auto& device_option = event(task_id).GetDeviceOption();
122+
pool(device_option)
123+
->run(std::bind(&AsyncSchedulingNet::pollAndSchedule, this, task_id));
145124
}
146125
}
147126

@@ -177,14 +156,7 @@ bool AsyncSchedulingNet::DoRunAsync() {
177156
return true;
178157
}
179158

180-
AsyncSchedulingNet::~AsyncSchedulingNet() {
181-
for (auto& task_queue : pending_tasks_) {
182-
task_queue->NoMoreJobs();
183-
}
184-
for (auto& polling_thread : polling_threads_) {
185-
polling_thread.join();
186-
}
187-
}
159+
AsyncSchedulingNet::~AsyncSchedulingNet() {}
188160

189161
REGISTER_NET(async_scheduling, AsyncSchedulingNet);
190162

caffe2/core/net_async_scheduling.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class AsyncSchedulingNet : public AsyncNetBase {
3333
protected:
3434
bool DoRunAsync() override;
3535

36-
void pollAndSchedule(int thread_id);
36+
void pollAndSchedule(int task_id);
3737
void schedule(int task_id);
3838
void reset();
3939
void finishRun();
@@ -49,10 +49,6 @@ class AsyncSchedulingNet : public AsyncNetBase {
4949

5050
std::atomic<int> processed_tasks_num_;
5151

52-
std::vector<std::unique_ptr<SimpleQueue<int>>> pending_tasks_;
53-
std::vector<std::thread> polling_threads_;
54-
std::atomic<int> next_polling_thread_counter_;
55-
5652
DISABLE_COPY_AND_ASSIGN(AsyncSchedulingNet);
5753
};
5854

0 commit comments

Comments
 (0)