Skip to content

Commit 569c115

Browse files
authored
[SymMem 2/5] SymmetricTensor runtime type (#5517)
PR series: #5516, **#5517 (this PR)**, #5518, #5519, #5520. Full branch for reference: #5515.
1 parent 00f626d commit 569c115

File tree

7 files changed

+881
-28
lines changed

7 files changed

+881
-28
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -313,6 +313,7 @@ list(APPEND NVFUSER_SRCS
313313
${NVFUSER_SRCS_DIR}/multidevice/propagation.cpp
314314
${NVFUSER_SRCS_DIR}/multidevice/resharding.cpp
315315
${NVFUSER_SRCS_DIR}/multidevice/utils.cpp
316+
${NVFUSER_SRCS_DIR}/multidevice/symmetric_tensor.cpp
316317
${NVFUSER_SRCS_DIR}/mutator.cpp
317318
${NVFUSER_SRCS_DIR}/ops/alias.cpp
318319
${NVFUSER_SRCS_DIR}/ops/arith.cpp
@@ -1188,6 +1189,7 @@ if(BUILD_TEST)
11881189
${NVFUSER_ROOT}/tests/cpp/test_multidevice_stream_parallel_type.cpp
11891190
${NVFUSER_ROOT}/tests/cpp/test_multidevice_transformer.cpp
11901191
${NVFUSER_ROOT}/tests/cpp/test_multidevice_ipc.cpp
1192+
${NVFUSER_ROOT}/tests/cpp/test_multidevice_symmetric_tensor.cpp
11911193
)
11921194
add_test_without_main(test_multidevice "${MULTIDEVICE_TEST_SRCS}" "")
11931195
list(APPEND TEST_BINARIES test_multidevice)

csrc/multidevice/ipc_handle.cpp

Lines changed: 1 addition & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -8,25 +8,10 @@
88
#include <cuda_utils.h>
99
#include <multidevice/communicator.h>
1010
#include <multidevice/ipc_handle.h>
11+
#include <multidevice/utils.h>
1112

1213
namespace nvfuser {
1314

14-
namespace {
15-
16-
template <typename T>
17-
std::vector<uint8_t> toBytes(const T& data) {
18-
return std::vector<uint8_t>(
19-
reinterpret_cast<const uint8_t*>(&data),
20-
reinterpret_cast<const uint8_t*>(&data) + sizeof(T));
21-
}
22-
23-
template <typename T>
24-
const T& fromBytes(const std::vector<uint8_t>& bytes) {
25-
return *reinterpret_cast<const T*>(bytes.data());
26-
}
27-
28-
} // namespace
29-
3015
IpcHandle::IpcHandle(at::Tensor tensor)
3116
: ptr_(tensor.data_ptr()),
3217
rank_(Communicator::getInstance().deviceId()),

0 commit comments

Comments
 (0)