[QST] MatX is around x15 slower than CuPy for the same task #688
Comments
Hi @HugoPhibbs, this is very interesting and unexpected. We'll take a look at the profile and get back to you. |
Can the batches all run in parallel and get the same answer? We generally suggest removing batch loops like this and just sending in the entire tensor. It's simpler and faster.
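For illustration, here is a rough, untested sketch of what that could look like, reusing the tensor and operator names from the batched `findDistancesMatX` code that appears later in this thread (so everything here is an assumption about shapes, not a drop-in replacement):

```cpp
#include "matx.h"

// Untested sketch: the batched loop collapsed into a single operator expression.
// Shapes follow the batched code later in the thread: X_t is (n, d), A_t is (n, 2k),
// B_t holds row indices into X_t, and distances_t ends up (n, 2*k*m).
auto findDistancesNoBatch(matx::tensor_t<matx::matxFp16, 2> &X_t,
                          matx::tensor_t<int, 2> &A_t,
                          matx::tensor_t<int, 2> &B_t) {
    const int k = A_t.Shape()[1] / 2;
    const int m = B_t.Shape()[1];
    const int n = X_t.Shape()[0];
    const int d = X_t.Shape()[1];

    auto distances_t = matx::make_tensor<matx::matxFp16>({n, 2*k*m}, matx::MATX_DEVICE_MEMORY);

    auto AFlat_t = matx::flatten(A_t);                             // (n*2k,) indices into B_t
    auto B_op    = matx::remap<0>(B_t, AFlat_t);                   // (n*2k, m) indices into X_t
    auto X_op    = matx::remap<0>(X_t, matx::flatten(B_op));       // (n*2k*m, d) gathered points
    auto X_resh  = matx::reshape(X_op, {n, 2*k*m, d});
    auto X_self  = matx::reshape(X_t, {n, 1, d});
    auto Y_op    = X_resh - matx::repmat(X_self, {1, 2*k*m, 1});   // difference to each candidate

    // One reduction over the whole problem instead of one per batch
    (distances_t = matx::vector_norm(Y_op, {2}, matx::NormOrder::L2)).run();
    cudaDeviceSynchronize();
    return distances_t;
}
```

Whether an un-batched expression like this fits in GPU memory is exactly the concern raised in the next comment, so treat it as an illustration of the style rather than a recommendation.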
|
Hi @luitjens I'm using batches because otherwise, my GPU quickly runs out of memory (I tried no batching with CuPy and this was the result). Batching is used to control the memory usage of intermediary tensors. I intend in the future to tune the batch size to produce optimal memory usage of the GPU, but right now, I'm focused on getting an MVP. |
Are you timing the allocation and page faults of managed memory as part of the execution time? If you switch to CUDA device memory instead of managed, does the perf issue go away?
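To make the distinction concrete, a minimal sketch of the two allocation modes (assuming MatX's `MATX_MANAGED_MEMORY` and `MATX_DEVICE_MEMORY` memory-space tags; the latter also appears in the code further down):

```cpp
#include "matx.h"

int main() {
    int n = 70000, k = 5;

    // Managed (unified) memory: reachable from host and device, but first-touch
    // page faults and migrations can land inside whatever region you are timing.
    auto A_managed = matx::make_tensor<float>({n, 2*k}, matx::MATX_MANAGED_MEMORY);

    // Device memory: allocated on the GPU and never migrated, so host timers around
    // the kernels are not inflated by paging (but the host cannot dereference it).
    auto A_device = matx::make_tensor<float>({n, 2*k}, matx::MATX_DEVICE_MEMORY);

    (A_managed = matx::random<float>({n, 2*k}, matx::UNIFORM, 0, 1)).run();
    (A_device  = A_managed).run();
    cudaDeviceSynchronize();

    return 0;
}
```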
|
Hi @luitjens, thx for getting back to me. I'm timing the complete function runtime - as in how long it takes to run the function start to finish. The timing looks a bit like this:

```cpp
TEST_F(TestFindingDistances, TestLargeInputMatX) {
    int k = 5;
    int n = 70000;
    int m = 50;
    int D = 1024;
    int d = 784;

    auto A = tu::createMockAMatrixMatX(n, k, D);
    auto B = tu::createMockBMatrixMatX(n, m, D);
    auto X = tu::createMockMnistDatasetMatX(n, d);

    cudaDeviceSynchronize(); // Possibly not necessary?

    tu::Time start = tu::timeNow();

    auto distances = GsDBSCAN::findDistancesMatX(X, A, B, 1.2, 250);
    cudaDeviceSynchronize();

    tu::printDurationSinceStart(start);
    printf("%lld %lld", distances.Shape()[0], distances.Shape()[1]);

    ASSERT_TRUE(distances.Shape()[0] == n);
    ASSERT_TRUE(distances.Shape()[1] == 2*k*m);
}
```

As for memory options, I changed the memory space of all the tensors to device memory (`matx::MATX_DEVICE_MEMORY`), e.g.:

```cpp
inline auto createMockAMatrixMatX(int n = 70000, int k = 2, int D = 1024) {
    auto A = matx::make_tensor<float>({n, 2*k}, matx::MATX_DEVICE_MEMORY);
    auto A_i = matx::make_tensor<int32_t>({n, 2*k}, matx::MATX_DEVICE_MEMORY);

    int a = 2 * (D - 1);

    (A = matx::random<float>({n, 2*k}, matx::UNIFORM, 0, a)).run();
    (A_i = matx::as_type<int32_t>(A)).run();

    return A_i;
}
```
|
Hi again, I've done some more testing, and I've found that the CUDA synchronise step takes the lion's share of the runtime. I added some hacky profiling to the function, which looks like this:

```cpp
matx::tensor_t<matx::matxFp16, 2> GsDBSCAN::findDistancesMatX(matx::tensor_t<matx::matxFp16, 2> &X_t, matx::tensor_t<int, 2> &A_t, matx::tensor_t<int, 2> &B_t, float alpha, int batchSize) {
    const int k = A_t.Shape()[1] / 2;
    const int m = B_t.Shape()[1];
    const int n = X_t.Shape()[0];
    const int d = X_t.Shape()[1];
    int D = B_t.Shape()[0] / 2;

    batchSize = (batchSize != -1) ? batchSize : GsDBSCAN::findDistanceBatchSize(alpha, n, d, k, m);

    auto AFlat_t = matx::flatten(A_t);
    auto distances_t = matx::make_tensor<matx::matxFp16>({n, 2*k*m}, matx::MATX_DEVICE_MEMORY);

    int j = 0;
    std::vector<double> times;

    auto start_all = std::chrono::high_resolution_clock::now();

    for (int i = 0; i < n; i += batchSize) {
        auto start = std::chrono::high_resolution_clock::now();

        int maxBatchIdx = i + batchSize - 1; // Index within X along the ROWS

        auto XSubset_t_op = matx::slice(X_t, {i, 0}, {maxBatchIdx + 1, matx::matxEnd});
        auto ABatchFlat_t_op = matx::slice(AFlat_t, {i * 2 * k}, {(maxBatchIdx + 1) * 2 * k});
        auto BBatch_t_op = matx::remap<0>(B_t, ABatchFlat_t_op);
        auto XBatch_t_op = matx::remap<0>(X_t, matx::flatten(BBatch_t_op));
        auto XBatchReshaped_t_op = matx::reshape(XBatch_t_op, {batchSize, 2*k*m, d});
        auto XSubsetReshaped_t_op = matx::reshape(XSubset_t_op, {batchSize, 1, d});

        auto YBatch_t_op = (XBatchReshaped_t_op - matx::repmat(XSubsetReshaped_t_op, {1, 2*k*m, 1})); // Repmat is a workaround for subtracting naively incompatible tensor shapes
        auto YBatch_t_norm_op = matx::vector_norm(YBatch_t_op, {2}, matx::NormOrder::L2);

        (matx::slice(distances_t, {i, 0}, {maxBatchIdx + 1, matx::matxEnd}) = YBatch_t_norm_op).run();

        // Record end time
        auto end = std::chrono::high_resolution_clock::now();

        // Calculate the duration
        std::chrono::duration<double> duration = end - start;

        // Cast to double and store in array
        times.push_back(duration.count());
    }

    auto start_sync = std::chrono::high_resolution_clock::now();

    cudaDeviceSynchronize();

    // Record end time
    auto end_sync = std::chrono::high_resolution_clock::now();

    // Calculate the duration
    std::chrono::duration<double> duration_sync = end_sync - start_sync;

    // Output the duration
    std::cout << "Time taken: " << duration_sync.count() << " seconds" << std::endl;

    for (const auto& element : times) {
        std::cout << element << std::endl;
    }

    // Record end time
    auto end_all = std::chrono::high_resolution_clock::now();

    // Calculate the duration
    std::chrono::duration<double> duration = end_all - start_all;

    // Output the duration
    std::cout << "Time taken: " << duration.count() << " seconds" << std::endl;

    return distances_t;
}
```

Which produces the output:

Has this got something to do with the fact that MatX looks to have an async execution style? I.e. adding a bunch of async operations to the queue on the GPU may produce a large bottleneck effect? - Just an idea |
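For what it's worth, MatX's `run()` does enqueue work asynchronously on a CUDA stream, so the per-iteration host timers above mostly measure launch overhead, while the final `cudaDeviceSynchronize()` absorbs all of the queued GPU execution time. A minimal standalone sketch of the effect (a hypothetical example, not the repro code):

```cpp
#include "matx.h"
#include <chrono>
#include <iostream>

int main() {
    auto a = matx::make_tensor<float>({70000, 1024});
    auto b = matx::make_tensor<float>({70000, 1024});
    (a = matx::random<float>({70000, 1024}, matx::UNIFORM, 0, 1)).run();
    cudaDeviceSynchronize();

    auto t0 = std::chrono::high_resolution_clock::now();
    (b = a * 2.0f).run();            // returns almost immediately: work is only enqueued
    auto t1 = std::chrono::high_resolution_clock::now();
    cudaDeviceSynchronize();         // the host waits here for the actual kernel to finish
    auto t2 = std::chrono::high_resolution_clock::now();

    std::cout << "enqueue: " << std::chrono::duration<double>(t1 - t0).count() << " s, "
              << "sync: "    << std::chrono::duration<double>(t2 - t1).count() << " s\n";
    return 0;
}
```

In other words, the per-batch numbers printed above say little about where the GPU actually spends its time; only the synchronized total does.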
Hi, can you please provide a fully buildable/runnable example in both MatX and Python that we can use to compare? Generally speaking, you don't want to include allocation time in your timings, as you want to allocate once upfront and reuse. |
Alternatively, if you cannot easily create a standalone reproducer, can you share an nsys profile of both the Python and MatX versions with us? |
Ok thx, pls see the gist: https://gist.github.com/HugoPhibbs/a2ce2c75b70c6737f1094f32b15af3ea It contains source files to run it, along with an nsys profile |
I recreated your repro as an example and had to make a few modifications to get it to build. Once I did that, I ran it on an H100 and I see this output:

```
Total Time taken: 0.0242754 seconds
```

Unfortunately I was not able to view your profile as it says it is corrupt. Could you get a fresh profile, put it in a zip and upload it to your example? |
Hugo, I created a repro with some build fixes here: https://github.com/NVIDIA/MatX/tree/688-repro
From your build directory, can you verify that the issue still reproduces? |
On an L40 I see similar performance:

```
Total Time taken: 0.0205341 seconds
```
|
Ok thanks. Honestly I'm a little bit skeptical that it would take just a fraction of a second. But yes, the error still reproduces on my machine:

```
make repro
[ 50%] Building CUDA object examples/CMakeFiles/repro.dir/repro.cu.o
[100%] Linking CUDA executable repro
[100%] Built target repro

./examples/repro
Sync Time taken: 14.3653 seconds
0.00239233
2.32e-05
....
1.776e-05
Total Time taken: 14.3747 seconds
Total Time taken (again): 14.375 seconds
70000 500
```

Please see this zip for the profile: test_profile.zip - may have been an encoding issue. I guess now would be a good time to show you my environment:

```
nvidia-smi
Wed Aug 7 09:38:27 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.28.03 Driver Version: 560.28.03 CUDA Version: 12.6 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 3090 Off | 00000000:01:00.0 On | N/A |
| 53% 45C P5 63W / 390W | 1022MiB / 24576MiB | 47% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA GeForce RTX 3090 Off | 00000000:4A:00.0 On | N/A |
| 0% 48C P8 49W / 390W | 115MiB / 24576MiB | 25% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
```
|
Thank you for the profile. When I inspect the profiles, the MatX-generated kernels seem in line with the hardware (2 ms on H100 and 6.5 ms on the 3090). However, the reduction kernel seems way off. We use CUB for this kernel, so perhaps there is something going wrong in CUB. We will investigate this. As a workaround, can you try to materialize the inputs to the reduction kernel into a memory-backed tensor, then compute the vector norm on the memory-backed tensor: https://github.com/NVIDIA/MatX/blob/688-repro/examples/repro.cu#L96 |
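For concreteness, a rough sketch of that workaround applied inside the batch loop from the earlier comment (variable names and shapes are taken from that code; the actual change in the 688-repro branch may differ):

```cpp
// Materialize the operator chain into a memory-backed tensor first...
// (in practice, allocate YBatch_t once outside the loop and reuse it)
auto YBatch_t = matx::make_tensor<matx::matxFp16>({batchSize, 2*k*m, d}, matx::MATX_DEVICE_MEMORY);
(YBatch_t = XBatchReshaped_t_op - matx::repmat(XSubsetReshaped_t_op, {1, 2*k*m, 1})).run();

// ...then hand the plain tensor (a simple iterator) to the CUB-backed reduction
(matx::slice(distances_t, {i, 0}, {maxBatchIdx + 1, matx::matxEnd}) =
        matx::vector_norm(YBatch_t, {2}, matx::NormOrder::L2)).run();
```

Hoisting the intermediate tensor allocation out of the loop also keeps allocation cost out of the timed region, per the earlier comment about timing.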
can you also get me an ncu profile with this command on your system:
Then zip up 309.ncu-rep and attach that too. |
updated ncu instruction above |
Also, can you try updating your toolkit? You currently have CUDA 11.8; I'd suggest going to 12.5. |
@HugoPhibbs I ran this on both an A100 and 3090. Here are the results:
3090:
This is CUDA 12.5. I will try 11.8 and report back. |
@HugoPhibbs on my nsys capture I see 32 registers per thread whereas @luitjens pointed out you had 128. Here is our compilation line:
Can you send what yours looks like? If you're importing the |
@HugoPhibbs I was able to reproduce your issue on CUDA 11.8 with everything else the same:
Is it possible for you to update? This may be an issue where nvcc had trouble with register reuse in this case causing poor occupancy. |
Thx @cliffburdick and @luitjens.

@luitjens re ncu, currently waiting for admin permissions to run `sudo ncu ...`, I'll send results once I can.

As on the front of upgrading CUDA, I upgraded to 12.5, it runs ok, but now my tests are broken 🙃. Just to make sure, when I do `cudaDeviceSynchronize()` this makes sure that any pending operations on the GPU are done, right? When I upgrade to 12.5, the returned `distances_t` tensor is now just empty (full of zeros) - whereas with 11.8 it was full of values.

E.g. my simple tests look a bit like:

```cpp
auto distances_t = GsDBSCAN::findDistancesMatX(X_t_16, A_t, B_t);
cudaDeviceSynchronize();

matx::matxFp16 *distances_ptr = distances_t.Data();

matx::matxFp16 expected_squared[] = {
        11, 5, 14, 11, 0, 5,
        9, 0, 11, 0, 14, 11,
        5, 0, 5, 5, 8, 14,
        9, 5, 0, 0, 9, 5,
        9, 6, 5, 5, 0, 6
};

for (int i = 0; i < 5*6; i++) {
    ASSERT_NEAR(std::sqrt(expected_squared[i]), distances_ptr[i], 1e-3); // distances is full of zeros with 12.5 but actually full of values in 11.8
}
```

Do you guys know a reason why this may be? |
First thing I'd do is drop this macro call into your code, specifically after syncs. This will verify that no CUDA errors occurred:
https://gist.github.com/jefflarkin/5390993
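A generic error-checking macro along those lines looks roughly like this (a sketch in the spirit of the linked gist, not copied from it):

```cpp
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Call after kernel launches / synchronizations to surface any pending CUDA error.
#define cudaCheckError() {                                              \
    cudaError_t e = cudaGetLastError();                                 \
    if (e != cudaSuccess) {                                             \
        std::printf("CUDA failure %s:%d: '%s'\n", __FILE__, __LINE__,   \
                    cudaGetErrorString(e));                             \
        std::exit(EXIT_FAILURE);                                        \
    }                                                                   \
}
```

For example, calling `cudaDeviceSynchronize(); cudaCheckError();` right after the batched `run()` calls will surface any asynchronous launch failure.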
|
@luitjens yep added the macro and no errors occur |
Ok, can you create a PR to modify the repro branch which asserts that there is an error? Then in the morning I will dig into it.
|
sure will do |
Can you also verify you are not trying to dereference a device pointer on
the host?
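For example (a hypothetical illustration, reusing the names from the test snippet above): with `MATX_DEVICE_MEMORY`, the raw pointer returned by `Data()` is a device pointer, so the host-side assertion loop would need an explicit copy first, whereas managed memory can be read directly after a sync:

```cpp
// Copy device-resident results to the host before comparing on the CPU.
std::vector<matx::matxFp16> host_vals(5 * 6);
cudaMemcpy(host_vals.data(), distances_t.Data(),
           host_vals.size() * sizeof(matx::matxFp16), cudaMemcpyDeviceToHost);

for (int i = 0; i < 5 * 6; i++) {
    ASSERT_NEAR(std::sqrt(expected_squared[i]), host_vals[i], 1e-3);
}
```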
|
Ok, have checked. But I don't think this is the case, since I'm using managed memory? I was getting a seg fault when using device memory, so I changed the memory to managed and it worked (in CUDA 11.8). |
Ok yes managed is fine. I will review in the morning.
|
@HugoPhibbs we're still looking into it, but we can reproduce your issue. |
@luitjens @cliffburdick have tried a fresh rebuild and using floats instead of halves. Behaviour is more or less the same. I'm getting strange behaviour when I run it a
This is the file I'm running: https://gist.github.com/HugoPhibbs/1bfd7180119040186b57b515dff4f69d
That file was run from the
I've commented out the print loop. Can you guys check if it has the same behaviour for you? - i.e. if it's all zeros? |
Re gtest being slower, I can confirm that this is instead probably something to do with my project. When I copy the
If it's helpful, my cmake looks like:

```cmake
cmake_minimum_required(VERSION 3.27)
project(DbscanCEOs LANGUAGES CXX CUDA C)

#enable_language(CUDA)

project(sDbscan)

set(CMAKE_CXX_STANDARD 17)
SET(CMAKE_CUDA_ARCHITECTURES 86)
set(CMAKE_CUDA_COMPILER "/usr/local/cuda-12.6/bin/nvcc") # Somehow CLion needs this here (smh)
#SET(CMAKE_C_COMPILER "/usr/bin/g++")

#add_definitions(-DINDEX_64_BIT)
#SET(CMAKE_BUILD_TYPE Debug)

find_package(Eigen3 3.3 REQUIRED NO_MODULE)
find_package(Boost 1.71 REQUIRED NO_MODULE)

# CCCL
include(cmake/CPM.cmake)

# This will automatically clone CCCL from GitHub and make the exported cmake targets available
CPMAddPackage(
        NAME CCCL
        GITHUB_REPOSITORY nvidia/cccl
        GIT_TAG v2.4.0
)

find_package(OpenMP)
if (OPENMP_FOUND)
    set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
    set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
    set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}")
endif()

set(CUDA_TOOLKIT_ROOT_DIR $ENV{CUDA_HOME})

find_package(CUDAToolkit 12.6 REQUIRED)
#find_package(CUDAToolkit 11.8 REQUIRED)

# ArrayFire
find_package(ArrayFire REQUIRED)

# MatX https://github.com/NVIDIA/MatX
find_package(matx CONFIG REQUIRED)

include_directories(
        ${PROJECT_SOURCE_DIR}/lib/eigen-3.4.0
        ${CUDA_TOOLKIT_ROOT_DIR}/include
        ${gtest_SOURCE_DIR}/include
        ${gtest_SOURCE_DIR}
)

link_directories(
        ${CUDA_TOOLKIT_ROOT_DIR}/lib64
)

add_subdirectory(lib/googletest)

set_source_files_properties(
        test/gsDBSCAN/GsDBSCANTest.cpp
        test/gsDBSCAN/UtilsTest.cpp
        test/gsDBSCAN/PreprocessingTest.cpp
        test/gsDBSCAN/DistancesTest.cpp
        test/gsDBSCAN/ClusteringTest.cpp
        include/gsDBSCAN/clustering.h
        test/repro2.cu
        PROPERTIES LANGUAGE CUDA)

add_executable(${PROJECT_NAME}
        src/main.cpp
        src/Utilities.cpp
        src/dbscan/sDbscan.cpp
        src/fast_copy.c
        src/fht/fht.c
        include/gsDBSCAN/preprocessing.h
        include/gsDBSCAN/utils.h
        include/gsDBSCAN/distances.h
        include/gsDBSCAN/clustering.h
)

add_executable(run_gs_dbscan_tests
        test/gsDBSCAN/GsDBSCANTest.cpp
        test/TestUtils.cpp
        test/gsDBSCAN/UtilsTest.cpp
        test/gsDBSCAN/PreprocessingTest.cpp
        test/gsDBSCAN/DistancesTest.cpp
        test/gsDBSCAN/ClusteringTest.cpp
)

add_executable(repro2
        test/repro2.cu
)

target_link_libraries(run_gs_dbscan_tests PRIVATE CCCL::CCCL CUDA::cudart ArrayFire::afcuda matx::matx gtest gtest_main)
target_link_libraries(${PROJECT_NAME} PRIVATE CCCL::CCCL CUDA::cudart ArrayFire::afcuda Eigen3::Eigen matx::matx)
target_link_libraries(repro2 PRIVATE CCCL::CCCL matx::matx)
```
|
I would issue a |
I retested this on the main branch of MatX. I'm now getting 13.8 seconds or so for both my repo and within the MatX repo for running I made the mistake of running This is what it looks like: From MatX repo (up to date with 8th August, so it should have the fix for the
From my repo
Would you be able to run Thanks v much |
Hey, sorry if I'm being a bit pestering, but just a quick nudge on this. Have you had a chance to reproduce the same results? Thx for all your help so far. |
Hi @HugoPhibbs , not pestering at all. Got caught up in other things but will take a look tomorrow. |
@HugoPhibbs I was able to reproduce it on the 3090. Will do some more investigating. |
Hi @HugoPhibbs, we tracked down the problem to register spills, possibly caused by the size of the iterator going into CUB. When you materialize the operator into a tensor before calling the norm, it improves the speed by about 4x. That's still not as fast as CuPy, so we'd like to keep looking into the remaining performance issue until it's faster than CuPy, but we wanted to make sure the two versions are doing the same computation. Can you paste the entire CuPy example, including the main function, so we can run it standalone? |
@cliffburdick sure, can do. Here's the CuPy code:
I'm quite confident that they are doing the same computation, as they produce identical results for n=1000, and it's just doing the same thing over and over again with batches anyway. |
I'm getting 40s on a 3090 in Python with your code. Is that what you expect? It seems a lot slower than what you said above. |
@cliffburdick that is very odd, not what I'd expect at all. I'm getting 0.73 seconds or so:
See this Gist containing updated print statements: https://gist.github.com/HugoPhibbs/ba3ae26c9ff09ea997ece53c9b856399 My CuPy package is |
You're right, it was a caching issue. I'm getting close to yours when the JIT has completed. |
Just so we're on the same page, what is the exact code fix you applied to materialise the operator into a tensor before calling the norm? |
Hi @HugoPhibbs, it was:
|
@HugoPhibbs we have a fix/suggestion that will get you to 2x faster than cuPy, but we still have some work to do to go faster. We should be able to commit the first patch soon. |
@HugoPhibbs can you please pull the latest commit and compile with |
Hi @cliffburdick, thanks for the help. I tried the change (rebuilt and installed etc.); it certainly sped up the runtime, however it still takes 5 seconds (an improvement nonetheless). I've since reimplemented the code in question with Torch. The runtime of Torch is 0.79 seconds (very similar to CuPy), while the runtime of MatX is around 5 seconds. Please see the attached profiles for running my overall algorithm - I ran it by running the code in question (it's a part of an overall algorithm) with Torch, and then with MatX. The end results of both runs are very similar, except that MatX takes remarkably longer - the profile looks much different too. |
Thanks @HugoPhibbs , was this still a 3090? I was getting 0.5s when I tested on that card, but I will retry. |
yeah. Actually, could you send the code you're using? |
We are using the repro.cu from this branch: https://github.com/NVIDIA/MatX/tree/688-repro |
Sorry for the delay, I brought changes from #90bf114 into the latest from the main. And ran it, I had to disable the treat warnings as errors setting in the main cmake. The results are:
|
Hi @HugoPhibbs. I rebased 688-repro and pushed. After doing that and building with 32-bit support on a 3090, here are my results:
Are you running the example exactly as it is? |
mb, forgot to build with 32 bit. my results are:
These are different from yours - even so, your results are still about the same speed as CuPy? For debugging my
|
Hi @HugoPhibbs, yes, it seems to be about the same as PyTorch. Previously I was comparing to your CuPy results. We have further optimizations we can make, but I think before doing that we're better off looking at your Python code and reimplementing it, versus trying to optimize something that may not represent the original problem. Are you able to post your PyTorch code? |
Hi @cliffburdick. Yep sure, I'm using LibTorch (C++ Torch) though - not PyTorch. See my code here. [Edit]: Realised LibTorch might not be all that helpful - there is no run script for random data much like what we are doing above. Here is the PyTorch code I used for prototyping here. The result is:
|
@HugoPhibbs can you provide a main function too so we can run the whole thing? |
Hi @cliffburdick, Do you mean for the LibTorch? (LibTorch is a little hard to set up) The Pytorch code has code at the bottom you can run. |
I missed that. Thanks |
Gidday.
I'm a bit of a novice with MatX and CPP, and was looking to get some help with optimising my MatX code.
So basically I'm trying to refactor code that was first written in CuPy into lightning-fast MatX code. Except I find that my MatX implementation, despite being (what looks to me like) an identical equivalent of my CuPy code, is a lot slower. I was wondering if anybody would be able to give me some tips as to where my code might be slowing down.
FYI a general assumption is that MatX's operators are super lightweight - so the reshapes, repmats are all super quick.
My MatX code looks like:
And the same CuPy code looks like:
The parameters used for both are:
Regarding results, the MatX code takes around 14.5 seconds to complete, but CuPy takes 0.9 seconds (including Cuda Synchronisations).
As a baseline, a multithreaded (64 threads) CPU implementation of the above code (using loops with no tensors involved) takes less than 0.7 seconds. A single threaded CPU implementation takes around 7 seconds - (this is using the same machine of course).
Sorry if the variable names are a little cryptic.
I've tested for around
n = 1000
and found that the two implementations produce the same results (albeit with a small amount of floating point error). Thanks in advance.