@@ -16,6 +16,7 @@ limitations under the License.
 
 #include <optional>
 #include <string>
+#include <deque>
 #include <utility>
 
 #include "xla/status_macros.h"
@@ -24,6 +25,8 @@ limitations under the License.
 #include "tensorflow/core/platform/tensor_float_32_utils.h"
 #include "tensorflow/core/util/env_var.h"
 #include "tensorflow/core/util/matmul_autotune.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/stream_executor/stream_executor.h"
 
 namespace tensorflow {
 
@@ -44,33 +47,13 @@ int64_t GetWorkspaceLimit(int64_t default_value_in_bytes) {
   return default_value_in_bytes;
 }
 
-std::string BlasLtMatmulPlanParams::ToString() const {
-  return "";  // TODO
-}
-
 bool BlasLtMatmulPlanParams::operator==(
     const BlasLtMatmulPlanParams& other) const {
   return internal::AsTuple(*this) == internal::AsTuple(other);
 }
 
 namespace {
 
-// Thread-safe map from matmul parameters to their corresponding plan and
-// algorithms.
-struct BlasLtMatmulPlanMap {
-  absl::Mutex mu;
-
-  template <class... Args>
-  auto emplace(Args&&... args) {
-    absl::MutexLock lock(&mu);
-    return map_.emplace(std::forward<Args>(args)...);
-  }
-
- private:
-  absl::node_hash_map<BlasLtMatmulPlanParams, PlanAndAlgorithms> map_
-      ABSL_GUARDED_BY(mu);
-};
-
 int MatmulMaxAutotuneAlgorithmCount() {
   int64_t value;
   Status status =
@@ -110,19 +93,31 @@ StatusOr<se::blas::ComputationType> GetBlasComputationType(
 
 }  // namespace
 
-/* static */ StatusOr<const PlanAndAlgorithms*> PlanAndAlgorithms::GetOrCreate(
+/* static */ BlasLtMatmulPlanCache& BlasLtMatmulPlanCache::i(se::Stream* stream) {
+  static absl::Mutex m(absl::kConstInit);
+  // Each GPU gets its own cache instance.
+  static std::deque<BlasLtMatmulPlanCache> meta(8);
+  absl::MutexLock lock(&m);
+  size_t dev_id = stream->parent()->device_ordinal();
+  if (dev_id >= meta.size()) meta.resize(dev_id + 1);
+  return meta[dev_id];
+}
+
+/* static */ auto BlasLtMatmulPlanCache::GetOrCreate(
     se::Stream* stream, const BlasLtMatmulPlanParams& params,
-    absl::Mutex** ppmu, std::optional<int> max_algorithm_count) {
+    absl::Mutex** ppmu, std::optional<int> max_algorithm_count) -> StatusOr<const Entry*> {
   static const int64_t max_scratch_size =
       GetWorkspaceLimit(1LL << 32);  // 4GB by default
   static const int64_t max_autotune_algorithm_count =
       MatmulMaxAutotuneAlgorithmCount();
 
   if (!max_algorithm_count) max_algorithm_count = max_autotune_algorithm_count;
 
-  static BlasLtMatmulPlanMap plan_map;
+  auto& self = BlasLtMatmulPlanCache::i(stream);
 
-  auto [ptr, inserted] = plan_map.emplace(params, PlanAndAlgorithms{});
+  absl::MutexLock lock(self.mutex_.get());
+  auto [ptr, inserted] = self.map_.emplace(params, Entry{});
+  auto& entry = ptr->second;
   if (inserted) {
     TF_ASSIGN_OR_RETURN(auto xlatype,
                         se::gpu::AsXlaPrimitiveType(params.dtype));
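
The header that declares BlasLtMatmulPlanCache is not part of this diff. Below is a minimal sketch of the class shape these changes assume, inferred only from the uses visible above (self.mutex_.get(), self.map_.emplace(params, Entry{}), entry.plan, entry.algorithms, and the trailing return type of GetOrCreate); the actual declaration in matmul_util.h may differ:

    // Hypothetical reconstruction, not code from this PR; member names are
    // taken from their uses in the .cc changes above.
    struct BlasLtMatmulPlanCache {
      struct Entry {
        se::gpu::BlasLt::MatmulPlanPtr plan;
        std::vector<se::gpu::BlasLt::MatmulAlgorithm> algorithms;
      };

      static StatusOr<const Entry*> GetOrCreate(
          se::Stream* stream, const BlasLtMatmulPlanParams& params,
          absl::Mutex** ppmu, std::optional<int> max_algorithm_count);

      static Status ExecuteOnStream(
          se::Stream* stream, const Entry& entry,
          const se::DeviceMemoryBase& a, const se::DeviceMemoryBase& b,
          se::DeviceMemoryBase& c, size_t algorithm_idx,
          se::ScratchAllocator& scratch_allocator,
          const se::DeviceMemoryBase& bias,
          se::blas::ProfileResult* profile_result);

     private:
      // Per-device accessor; defined in the diff above.
      static BlasLtMatmulPlanCache& i(se::Stream* stream);

      // Held via unique_ptr to match the self.mutex_.get() accesses above.
      std::unique_ptr<absl::Mutex> mutex_{new absl::Mutex};
      absl::node_hash_map<BlasLtMatmulPlanParams, Entry> map_;
    };

The std::deque in i() (rather than a std::vector) is presumably chosen because growing a deque does not invalidate references to existing elements, so the reference handed out for one device stays valid while the list grows to accommodate another.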
@@ -171,32 +166,28 @@ StatusOr<se::blas::ComputationType> GetBlasComputationType(
         .compute_type = computation_type,
     };
 
-    TF_ASSIGN_OR_RETURN(auto plan, se::gpu::BlasLt::GetMatmulPlan(
+    TF_ASSIGN_OR_RETURN(entry.plan, se::gpu::BlasLt::GetMatmulPlan(
                                        stream, cfg, params.epilogue));
 
     TF_ASSIGN_OR_RETURN(
-        auto algorithms,
-        plan->GetAlgorithms(*max_algorithm_count, max_scratch_size));
-
-    ptr->second = {std::move(plan), std::move(algorithms)};
+        entry.algorithms,
+        entry.plan->GetAlgorithms(*max_algorithm_count, max_scratch_size));
   }
-  *ppmu = &plan_map.mu;
-  return &ptr->second;
+  *ppmu = self.mutex_.get();
+  return &entry;
 }
 
-Status PlanAndAlgorithms::ExecuteOnStream(se::Stream* stream,
+/* static */ Status BlasLtMatmulPlanCache::ExecuteOnStream(se::Stream* stream,
+                                          const Entry& entry,
                                           const se::DeviceMemoryBase& a,
                                           const se::DeviceMemoryBase& b,
                                           se::DeviceMemoryBase& c,
                                           size_t algorithm_idx,
                                           se::ScratchAllocator& scratch_allocator,
                                           const se::DeviceMemoryBase& bias,
-                                          se::blas::ProfileResult* profile_result) const {
+                                          se::blas::ProfileResult* profile_result) {
 
-  if (!plan || algorithm_idx >= algorithms.size()) {
-    return errors::Internal("MatmulPlan or algorithms are not initialized!");
-  }
-  return plan->ExecuteOnStream(
+  return entry.plan->ExecuteOnStream(
       stream, a, b, c, c,
       bias,                    // bias_buffer
       se::DeviceMemoryBase{},  // aux_buffer
@@ -205,9 +196,8 @@ Status PlanAndAlgorithms::ExecuteOnStream(se::Stream* stream,
       se::DeviceMemoryBase{},  // c_scale_buffer
       se::DeviceMemoryBase{},  // d_scale_buffer
       se::DeviceMemoryBase{},  // d_amax_buffer
-      algorithms[algorithm_idx],
-      std::nullopt,  // workspace
-      &scratch_allocator,
+      entry.algorithms[algorithm_idx],
+      scratch_allocator,
       profile_result);
 }
 
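For orientation, a sketch of the calling convention the new signatures imply; this is not code from the PR, and the buffer and parameter names are stand-ins. The mutex returned through ppmu plays the same role the old plan_map.mu did: callers lock it while indexing into entry.algorithms. Note also that the old in-function check of algorithm_idx against algorithms.size() was dropped, so callers are now responsible for passing a valid index:

    // Illustrative caller flow only; params, a_buf, b_buf, c_buf, bias_buf,
    // and scratch_allocator are assumed to be provided by the kernel.
    absl::Mutex* pmu = nullptr;
    TF_ASSIGN_OR_RETURN(const auto* entry,
                        BlasLtMatmulPlanCache::GetOrCreate(
                            stream, params, &pmu,
                            /*max_algorithm_count=*/std::nullopt));
    absl::MutexLock lock(pmu);  // guards entry->algorithms across the launch
    TF_RETURN_IF_ERROR(BlasLtMatmulPlanCache::ExecuteOnStream(
        stream, *entry, a_buf, b_buf, c_buf, /*algorithm_idx=*/0,
        scratch_allocator, bias_buf, /*profile_result=*/nullptr));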