@@ -21,40 +21,19 @@ CAFFE2_DEFINE_bool(
21
21
false ,
22
22
" Always schedule child chains from parent chain" );
23
23
24
- CAFFE2_DEFINE_int (
25
- caffe2_net_async_polling_threads_num,
26
- 1 ,
27
- " Number of polling threads in async_scheduling executor" );
28
-
29
24
namespace caffe2 {
30
25
31
26
AsyncSchedulingNet::AsyncSchedulingNet (
32
27
const std::shared_ptr<const NetDef>& net_def,
33
28
Workspace* ws)
34
29
: AsyncNetBase(net_def, ws), running_(false ) {
35
- pending_tasks_.reserve (FLAGS_caffe2_net_async_polling_threads_num);
36
- for (auto thread_num = 0 ;
37
- thread_num < FLAGS_caffe2_net_async_polling_threads_num;
38
- ++thread_num) {
39
- pending_tasks_.push_back (caffe2::make_unique<SimpleQueue<int >>());
40
- }
41
-
42
- polling_threads_.reserve (FLAGS_caffe2_net_async_polling_threads_num);
43
- for (auto thread_num = 0 ;
44
- thread_num < FLAGS_caffe2_net_async_polling_threads_num;
45
- ++thread_num) {
46
- polling_threads_.push_back (
47
- std::thread (&AsyncSchedulingNet::pollAndSchedule, this , thread_num));
48
- }
49
-
50
30
reset ();
51
31
}
52
32
53
33
void AsyncSchedulingNet::reset () {
54
34
processed_tasks_num_ = 0 ;
55
35
cleanup_ = false ;
56
36
success_ = true ;
57
- next_polling_thread_counter_ = 0 ;
58
37
59
38
for (auto task_id = 0 ; task_id < tasksNum (); ++task_id) {
60
39
auto & task_ops = chains_[task_id];
@@ -90,9 +69,10 @@ void AsyncSchedulingNet::schedule(int task_id) {
90
69
canSchedule (child_id)) {
91
70
schedule (child_id);
92
71
} else {
93
- auto polling_thread_id = next_polling_thread_counter_++;
94
- polling_thread_id %= FLAGS_caffe2_net_async_polling_threads_num;
95
- pending_tasks_[polling_thread_id]->Push (child_id);
72
+ const auto & device_option = event (child_id).GetDeviceOption ();
73
+ pool (device_option)
74
+ ->run (std::bind (
75
+ &AsyncSchedulingNet::pollAndSchedule, this , child_id));
96
76
}
97
77
}
98
78
}
@@ -133,15 +113,14 @@ void AsyncSchedulingNet::schedule(int task_id) {
133
113
});
134
114
}
135
115
136
- void AsyncSchedulingNet::pollAndSchedule (int thread_id) {
137
- int task_id;
138
- while (pending_tasks_[thread_id]->Pop (&task_id)) {
139
- if (canSchedule (task_id) || cleanup_) {
140
- // force schedule the rest of the tasks if cleanup is started
141
- schedule (task_id);
142
- } else {
143
- pending_tasks_[thread_id]->Push (task_id);
144
- }
116
+ void AsyncSchedulingNet::pollAndSchedule (int task_id) {
117
+ if (canSchedule (task_id) || cleanup_) {
118
+ // force schedule the rest of the tasks if cleanup is started
119
+ schedule (task_id);
120
+ } else {
121
+ const auto & device_option = event (task_id).GetDeviceOption ();
122
+ pool (device_option)
123
+ ->run (std::bind (&AsyncSchedulingNet::pollAndSchedule, this , task_id));
145
124
}
146
125
}
147
126
@@ -177,14 +156,7 @@ bool AsyncSchedulingNet::DoRunAsync() {
177
156
return true ;
178
157
}
179
158
180
- AsyncSchedulingNet::~AsyncSchedulingNet () {
181
- for (auto & task_queue : pending_tasks_) {
182
- task_queue->NoMoreJobs ();
183
- }
184
- for (auto & polling_thread : polling_threads_) {
185
- polling_thread.join ();
186
- }
187
- }
159
+ AsyncSchedulingNet::~AsyncSchedulingNet () {}
188
160
189
161
REGISTER_NET (async_scheduling, AsyncSchedulingNet);
190
162
0 commit comments