Skip to content

Commit

Permalink
Merge pull request #35 from ROCm/charlifu/update_base_image_with_newe…
Browse files Browse the repository at this point in the history
…r_pytorch

Update base docker image with Pytorch 2.3
  • Loading branch information
charlifu authored Jun 5, 2024
2 parents 69ce080 + a6af475 commit 68cdb95
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Dockerfile.rocm
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# default base image
ARG BASE_IMAGE="rocm/pytorch:rocm6.1.1_ubuntu20.04_py3.9_pytorch_release-2.1.2"
ARG BASE_IMAGE="rocm/pytorch:rocm6.1.1_ubuntu20.04_py3.9_pytorch_staging"

ARG COMMON_WORKDIR=/app

Expand Down
4 changes: 2 additions & 2 deletions csrc/quantization/fp8/gemm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ torch::Tensor fp8_gemm(torch::Tensor& a, torch::Tensor& b, torch::Tensor& scaleA
auto d_scaleD = scaleD.data_ptr();

auto handle = at::cuda::getCurrentCUDABlasLtHandle();
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
auto stream = at::cuda::getCurrentCUDAStream();

hipblaslt_ext::GemmPreference gemmPref;
gemmPref.setMaxWorkspaceBytes(0);
Expand Down Expand Up @@ -218,7 +218,7 @@ torch::Tensor fp8_gemm_16(
auto d_scaleB = transpose_result ? scaleA.data_ptr() : scaleB.data_ptr();

auto handle = at::cuda::getCurrentCUDABlasLtHandle();
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
auto stream = at::cuda::getCurrentCUDAStream();

hipblaslt_ext::GemmPreference gemmPref;
gemmPref.setMaxWorkspaceBytes(0);
Expand Down

0 comments on commit 68cdb95

Please sign in to comment.