
Commit 75c9894

numerous hipblaslt related fixes & fp8 buffer_comparator fix

1 parent 8f61c78

File tree

- tensorflow/core/kernels/matmul_op_fused.cc
- tensorflow/core/kernels/matmul_op_impl.h
- tensorflow/core/kernels/matmul_util.cc
- tensorflow/core/kernels/matmul_util.h
- third_party/xla/xla/service/gpu/buffer_comparator.cu.cc
- third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc
- third_party/xla/xla/stream_executor/rocm/hipblaslt_wrapper.h

7 files changed: +99 −76 lines changed

tensorflow/core/kernels/matmul_op_fused.cc

Lines changed: 8 additions & 9 deletions
@@ -597,32 +597,31 @@ struct LaunchFusedMatMulOp<GPUDevice, T> {
                                        epilog_op};
     absl::Mutex* pmu;
     auto plan_and_algorithms_or =
-        PlanAndAlgorithms::GetOrCreate(stream, matmul_params, &pmu);
+        BlasLtMatmulPlanCache::GetOrCreate(stream, matmul_params, &pmu);
     OP_REQUIRES_OK(context, plan_and_algorithms_or.status());
     absl::MutexLock lock(pmu);
-    const auto* plan_and_algorithms = std::move(plan_and_algorithms_or).value();
-    const auto& algorithms = plan_and_algorithms->algorithms;
-    OP_REQUIRES(context, algorithms.size() > 0,
+    const auto& entry = *plan_and_algorithms_or.value();
+    OP_REQUIRES(context, entry.algorithms.size() > 0,
                 errors::InvalidArgument("No matmul algorithm returned!"));
 
     auto launch_func = [&](BlasScratchAllocator& scratch_allocator,
                            size_t alg_idx,
                            se::blas::ProfileResult* profile_result) {
-      return plan_and_algorithms->ExecuteOnStream(stream, a_ptr, b_ptr, c_ptr,
-          alg_idx, scratch_allocator, bias_ptr,
-          profile_result);
+      return BlasLtMatmulPlanCache::ExecuteOnStream(
+          stream, entry, a_ptr, b_ptr, c_ptr, alg_idx,
+          scratch_allocator, bias_ptr, profile_result);
     };
 
     size_t alg_idx = 0;
     if (use_autotune) {
       auto algorithm_config =
-          AutotuneMatmul(algorithms, matmul_params, context, launch_func);
+          AutotuneMatmul(entry.algorithms, matmul_params, context, launch_func);
 
       alg_idx = algorithm_config.algorithm();
     }
 
     OP_REQUIRES_OK(context, launch_func(scratch_allocator, alg_idx, nullptr));
-#endif
+#endif  // GOOGLE_CUDA || TF_HIPBLASLT
   }
 };

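The hunk above swaps the old instance-method API (plan_and_algorithms->ExecuteOnStream(...)) for the static BlasLtMatmulPlanCache interface. Distilled into a minimal caller-side sketch (the names a_ptr/b_ptr/c_ptr/bias_ptr and matmul_params come from this kernel; the fixed algorithm index stands in for the autotuning logic):

    absl::Mutex* pmu = nullptr;
    auto entry_or = BlasLtMatmulPlanCache::GetOrCreate(stream, matmul_params, &pmu);
    OP_REQUIRES_OK(context, entry_or.status());
    absl::MutexLock lock(pmu);  // the cache hands back its mutex to hold during use
    const auto& entry = *entry_or.value();
    OP_REQUIRES_OK(context,
                   BlasLtMatmulPlanCache::ExecuteOnStream(
                       stream, entry, a_ptr, b_ptr, c_ptr, /*algorithm_idx=*/0,
                       scratch_allocator, /*bias=*/bias_ptr));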
tensorflow/core/kernels/matmul_op_impl.h

Lines changed: 8 additions & 5 deletions
@@ -637,7 +637,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
     std::optional<int> max_algorithm_count;
     if (!use_autotune) max_algorithm_count = 1;
     absl::Mutex* pmu = nullptr;
-    auto plan_and_algorithms_or = PlanAndAlgorithms::GetOrCreate(
+    auto plan_and_algorithms_or = BlasLtMatmulPlanCache::GetOrCreate(
         stream, matmul_params, &pmu, max_algorithm_count);
     OP_REQUIRES_OK(context, plan_and_algorithms_or.status());
     absl::MutexLock lock(pmu);
@@ -660,8 +660,9 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
         // scratch space is deallocated between runs.
         BlasScratchAllocator scratch_allocator(context, max_scratch_size);
         Status cublas_launch_status =
-            plan_and_algorithms->ExecuteOnStream(stream, *a_ptrs[0],
-                *b_ptrs[0], *c_ptrs[0], i, scratch_allocator,
+            BlasLtMatmulPlanCache::ExecuteOnStream(stream,
+                *plan_and_algorithms,
+                *a_ptrs[0], *b_ptrs[0], *c_ptrs[0], i, scratch_allocator,
                 se::DeviceMemoryBase{}, &profile_result);
 
         VLOG(4) << " Autotune algorithm " << i
@@ -702,8 +703,10 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
 
       OP_REQUIRES_OK(
           context,
-          plan_and_algorithms->ExecuteOnStream(stream, *a_ptrs[0], *b_ptrs[0],
-              *c_ptrs[0], algorithm_idx, scratch_allocator));
+          BlasLtMatmulPlanCache::ExecuteOnStream(stream,
+              *plan_and_algorithms,
+              *a_ptrs[0], *b_ptrs[0], *c_ptrs[0],
+              algorithm_idx, scratch_allocator, se::DeviceMemoryBase{}));
     } else {  // requires mixed broadcasting
       const std::vector<int64_t>& a_batch_indices = bcast.x_batch_indices();
       const std::vector<int64_t>& b_batch_indices = bcast.y_batch_indices();

tensorflow/core/kernels/matmul_util.cc

Lines changed: 30 additions & 40 deletions
@@ -16,6 +16,7 @@ limitations under the License.
 
 #include <optional>
 #include <string>
+#include <deque>
 #include <utility>
 
 #include "xla/status_macros.h"
@@ -24,6 +25,8 @@ limitations under the License.
 #include "tensorflow/core/platform/tensor_float_32_utils.h"
 #include "tensorflow/core/util/env_var.h"
 #include "tensorflow/core/util/matmul_autotune.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/stream_executor/stream_executor.h"
 
 namespace tensorflow {
 
@@ -44,33 +47,13 @@ int64_t GetWorkspaceLimit(int64_t default_value_in_bytes) {
   return default_value_in_bytes;
 }
 
-std::string BlasLtMatmulPlanParams::ToString() const {
-  return "";  // TODO
-}
-
 bool BlasLtMatmulPlanParams::operator==(
     const BlasLtMatmulPlanParams& other) const {
   return internal::AsTuple(*this) == internal::AsTuple(other);
 }
 
 namespace {
 
-// Thread-safe map from matmul parameters to their corresponding plan and
-// algorithms.
-struct BlasLtMatmulPlanMap {
-  absl::Mutex mu;
-
-  template <class... Args>
-  auto emplace(Args&&... args) {
-    absl::MutexLock lock(&mu);
-    return map_.emplace(std::forward<Args>(args)...);
-  }
-
- private:
-  absl::node_hash_map<BlasLtMatmulPlanParams, PlanAndAlgorithms> map_
-      ABSL_GUARDED_BY(mu);
-};
-
 int MatmulMaxAutotuneAlgorithmCount() {
   int64_t value;
   Status status =
@@ -110,19 +93,31 @@ StatusOr<se::blas::ComputationType> GetBlasComputationType(
 
 }  // namespace
 
-/* static */ StatusOr<const PlanAndAlgorithms*> PlanAndAlgorithms::GetOrCreate(
+/* static */ BlasLtMatmulPlanCache& BlasLtMatmulPlanCache::i(se::Stream *stream) {
+  static absl::Mutex m(absl::kConstInit);
+  // Each GPU gets different cache instance
+  static std::deque<BlasLtMatmulPlanCache> meta(8);
+  absl::MutexLock lock(&m);
+  size_t dev_id = stream->parent()->device_ordinal();
+  if (dev_id >= meta.size()) meta.resize(dev_id + 1);
+  return meta[dev_id];
+}
+
+/* static */ auto BlasLtMatmulPlanCache::GetOrCreate(
     se::Stream* stream, const BlasLtMatmulPlanParams& params,
-    absl::Mutex** ppmu, std::optional<int> max_algorithm_count) {
+    absl::Mutex** ppmu, std::optional<int> max_algorithm_count) -> StatusOr<const Entry *> {
   static const int64_t max_scratch_size =
       GetWorkspaceLimit(1LL << 32);  // 4GB by default
   static const int64_t max_autotune_algorithm_count =
      MatmulMaxAutotuneAlgorithmCount();
 
   if (!max_algorithm_count) max_algorithm_count = max_autotune_algorithm_count;
 
-  static BlasLtMatmulPlanMap plan_map;
+  auto& self = BlasLtMatmulPlanCache::i(stream);
 
-  auto [ptr, inserted] = plan_map.emplace(params, PlanAndAlgorithms{});
+  absl::MutexLock lock(self.mutex_.get());
+  auto [ptr, inserted] = self.map_.emplace(params, Entry{});
+  auto& entry = ptr->second;
   if (inserted) {
     TF_ASSIGN_OR_RETURN(auto xlatype,
                         se::gpu::AsXlaPrimitiveType(params.dtype));
@@ -171,32 +166,28 @@ StatusOr<se::blas::ComputationType> GetBlasComputationType(
         .compute_type = computation_type,
     };
 
-    TF_ASSIGN_OR_RETURN(auto plan, se::gpu::BlasLt::GetMatmulPlan(
+    TF_ASSIGN_OR_RETURN(entry.plan, se::gpu::BlasLt::GetMatmulPlan(
                                        stream, cfg, params.epilogue));
 
     TF_ASSIGN_OR_RETURN(
-        auto algorithms,
-        plan->GetAlgorithms(*max_algorithm_count, max_scratch_size));
-
-    ptr->second = {std::move(plan), std::move(algorithms)};
+        entry.algorithms,
+        entry.plan->GetAlgorithms(*max_algorithm_count, max_scratch_size));
   }
-  *ppmu = &plan_map.mu;
-  return &ptr->second;
+  *ppmu = self.mutex_.get();
+  return &entry;
 }
 
-Status PlanAndAlgorithms::ExecuteOnStream(se::Stream* stream,
+/*static */ Status BlasLtMatmulPlanCache::ExecuteOnStream(se::Stream* stream,
+    const Entry& entry,
     const se::DeviceMemoryBase& a,
     const se::DeviceMemoryBase& b,
     se::DeviceMemoryBase& c,
     size_t algorithm_idx,
     se::ScratchAllocator& scratch_allocator,
    const se::DeviceMemoryBase& bias,
-    se::blas::ProfileResult* profile_result) const {
+    se::blas::ProfileResult* profile_result) {
 
-  if(!plan || algorithm_idx >= algorithms.size()) {
-    return errors::Internal("MatmulPlan or algorithms are not initialized!");
-  }
-  return plan->ExecuteOnStream(
+  return entry.plan->ExecuteOnStream(
      stream, a, b, c, c,
      bias,                     // bias_buffer
      se::DeviceMemoryBase{},   // aux_buffer
@@ -205,9 +196,8 @@ Status PlanAndAlgorithms::ExecuteOnStream(se::Stream* stream,
      se::DeviceMemoryBase{},   // c_scale_buffer
      se::DeviceMemoryBase{},   // d_scale_buffer
      se::DeviceMemoryBase{},   // d_amax_buffer
-      algorithms[algorithm_idx],
-      std::nullopt,  // workspace
-      &scratch_allocator,
+      entry.algorithms[algorithm_idx],
+      scratch_allocator,
       profile_result);
 }

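Two details of the new cache are easy to read past. First, the cache is now per device: i(stream) indexes a function-local static std::deque by device ordinal, and a deque (unlike a vector) never relocates existing elements when it grows at the end, so a cache reference obtained for one GPU stays valid when a later stream with a higher ordinal triggers resize(). Second, the per-cache absl::Mutex is held through a std::unique_ptr, presumably because absl::Mutex is not movable while deque::resize requires a movable element type. A standalone sketch of the same pattern (PerDeviceCache and CacheForDevice are illustrative names, not from the commit):

    #include <deque>
    #include "absl/synchronization/mutex.h"

    struct PerDeviceCache { /* per-GPU state ... */ };

    PerDeviceCache& CacheForDevice(int ordinal) {
      static absl::Mutex m(absl::kConstInit);
      // deque: growing at the end never moves existing elements, so
      // references handed out earlier remain valid after resize().
      static std::deque<PerDeviceCache> instances(8);
      absl::MutexLock lock(&m);
      if (ordinal >= static_cast<int>(instances.size()))
        instances.resize(ordinal + 1);
      return instances[ordinal];
    }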
tensorflow/core/kernels/matmul_util.h

Lines changed: 36 additions & 21 deletions
@@ -35,7 +35,8 @@ namespace tensorflow {
 int64_t GetWorkspaceLimit(int64_t default_value_in_bytes);
 
 struct BlasLtMatmulPlanParams {
-  std::string ToString() const;
+
+  std::string ToString() const { return "NOP"; }
   bool operator==(const BlasLtMatmulPlanParams& other) const;
 
   se::blas::DataType dtype;
@@ -50,26 +51,6 @@ struct BlasLtMatmulPlanParams {
   se::gpu::BlasLt::Epilogue epilogue = se::gpu::BlasLt::Epilogue::kDefault;
 };
 
-struct PlanAndAlgorithms {
-
-  static StatusOr<const PlanAndAlgorithms*> GetOrCreate(
-      se::Stream* stream, const BlasLtMatmulPlanParams& params, absl::Mutex** pmu,
-      std::optional<int> max_algorithm_count = std::nullopt
-  );
-
-  Status ExecuteOnStream(se::Stream* stream,
-      const se::DeviceMemoryBase& a,
-      const se::DeviceMemoryBase& b,
-      se::DeviceMemoryBase& c,
-      size_t algorithm_idx,
-      se::ScratchAllocator& scratch_allocator,
-      const se::DeviceMemoryBase& bias = se::DeviceMemoryBase{},
-      se::blas::ProfileResult* profile_result = nullptr) const;
-
-  se::gpu::BlasLt::MatmulPlanPtr plan;
-  std::vector<se::gpu::BlasLt::MatmulAlgorithm> algorithms;
-};
-
 namespace internal {
 
 inline auto AsTuple(const BlasLtMatmulPlanParams& p) {
@@ -85,6 +66,40 @@ H AbslHashValue(H h, const BlasLtMatmulPlanParams& params) {
   return H::combine(std::move(h), internal::AsTuple(params));
 }
 
+struct BlasLtMatmulPlanCache {
+  struct Entry {
+    se::gpu::BlasLt::MatmulPlanPtr plan;
+    std::vector< se::gpu::BlasLt::MatmulAlgorithm > algorithms;
+  };
+
+  static StatusOr<const Entry *> GetOrCreate(
+      se::Stream* stream, const BlasLtMatmulPlanParams& params, absl::Mutex** pmu,
+      std::optional<int> max_algorithm_count = std::nullopt
+  );
+
+  // helper function for plan execution
+  static Status ExecuteOnStream(se::Stream* stream,
+      const Entry& entry,
+      const se::DeviceMemoryBase& a,
+      const se::DeviceMemoryBase& b,
+      se::DeviceMemoryBase& c,
+      size_t algorithm_idx,
+      se::ScratchAllocator& scratch_allocator,
+      const se::DeviceMemoryBase& bias,
+      se::blas::ProfileResult* profile_result = nullptr);
+
+  BlasLtMatmulPlanCache() : mutex_(new absl::Mutex) {
+  }
+
+ private:
+  static BlasLtMatmulPlanCache& i(se::Stream *stream);
+
+  std::unique_ptr<absl::Mutex> mutex_;
+  absl::node_hash_map<BlasLtMatmulPlanParams, Entry> map_
+      ABSL_GUARDED_BY(mutex_);
+
+};  // BlasLtMatmulPlanCache
+
 }  // namespace tensorflow
 
 #endif  // GOOGLE_CUDA || TF_HIPBLASLT

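A note on the container choice in the new struct: BlasLtMatmulPlanParams can serve as the key because this header already defines operator== and an AbslHashValue overload (both built on internal::AsTuple), and absl::node_hash_map, unlike absl::flat_hash_map, allocates each value in its own node, so pointers to values survive rehashing. That pointer stability is what lets GetOrCreate hand out a long-lived const Entry* into the map. A self-contained illustration (not from the commit):

    #include <string>
    #include "absl/container/node_hash_map.h"

    int main() {
      absl::node_hash_map<std::string, int> m;
      int* p = &m.try_emplace("first", 1).first->second;
      for (int i = 0; i < 1000; ++i)           // force several rehashes
        m.try_emplace("k" + std::to_string(i), i);
      return *p;  // still valid: rehashing moves node pointers, not the values
    }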
third_party/xla/xla/service/gpu/buffer_comparator.cu.cc

Lines changed: 15 additions & 0 deletions
@@ -103,11 +103,16 @@ __global__ void xla_fp8_e5m2_comparison(__nv_fp8_storage_t* buffer_a,
 #endif  // GOOGLE_CUDA
 
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
+
 __global__ void xla_fp8_e4m3fnuz_comparison(__hip_fp8_storage_t* buffer_a,
                                             __hip_fp8_storage_t* buffer_b,
                                             float rel_error_threshold,
                                             uint64_t buffer_length,
                                             int* mismatch_count) {
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+  // NOTE: according to amd_hip_fp8.h, GFX1200 and GFX1201 support ocp __hip_fp8_e4m3
+  // but not __hip_fp8_e4m3_fnuz
+
   int idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx >= buffer_length) return;
   __hip_fp8_e4m3_fnuz elem_a_fp8, elem_b_fp8;
@@ -123,13 +128,18 @@ __global__ void xla_fp8_e4m3fnuz_comparison(__hip_fp8_storage_t* buffer_a,
 
   if (rel_error > rel_error_threshold || isnan(rel_error))
     atomicAdd(mismatch_count, 1);
+#else
+  // on unsupported architectures, this should not / cannot be used!
+  atomicAdd(mismatch_count, 1);
+#endif
 }
 
 __global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a,
                                             __hip_fp8_storage_t* buffer_b,
                                             float rel_error_threshold,
                                             uint64_t buffer_length,
                                             int* mismatch_count) {
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
   int idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx >= buffer_length) return;
   __hip_fp8_e5m2_fnuz elem_a_fp8, elem_b_fp8;
@@ -145,7 +155,12 @@ __global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a,
 
   if (rel_error > rel_error_threshold || isnan(rel_error))
     atomicAdd(mismatch_count, 1);
+#else
+  // on unsupported architectures, this should not / cannot be used!
+  atomicAdd(mismatch_count, 1);
+#endif
 }
+
 #endif  // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
 
 __global__ void xla_fp16_comparison(__half* buffer_a, __half* buffer_b,

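For context on the new guards: HIP compiles __global__ functions once per GPU target, and macros such as __gfx940__, __gfx941__, and __gfx942__ are predefined only while generating device code for that target. So the fnuz fp8 comparison bodies exist only in the MI300-class code objects; every other target compiles the #else stub, which flags every element as a mismatch rather than silently comparing unsupported types. A minimal sketch of the same guard pattern (illustrative kernel, not from the commit):

    #include <hip/hip_runtime.h>

    __global__ void guarded_check(const float* in, uint64_t n, int* mismatches) {
      uint64_t idx = threadIdx.x + blockIdx.x * (uint64_t)blockDim.x;
      if (idx >= n) return;
    #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
      // Real comparison path, compiled only for supported targets.
      if (isnan(in[idx])) atomicAdd(mismatches, 1);
    #else
      // Unsupported target: conservatively report a mismatch.
      atomicAdd(mismatches, 1);
    #endif
    }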
third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc

Lines changed: 1 addition & 1 deletion
@@ -618,4 +618,4 @@ absl::Status BlasLt::MatmulPlan::ExecuteOnStream(
 
 }  // namespace stream_executor
 
-#endif  // TF_HIPBLASLT
+#endif  // TF_HIPBLASLT

(The two lines render identically here; the change is whitespace-only.)

third_party/xla/xla/stream_executor/rocm/hipblaslt_wrapper.h

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ limitations under the License.
 #define XLA_STREAM_EXECUTOR_ROCM_HIPBLASLT_WRAPPER_H_
 
 #define __HIP_DISABLE_CPP_FUNCTIONS__
+#define LEGACY_HIPBLAS_DIRECT
 
 #include "rocm/rocm_config.h"
