Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 36 additions & 8 deletions comms/ncclx/meta/hints/Hints.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
#include "comms/ctran/algos/AllToAll/AllToAllvDynamicHintUtils.h"
#include "comms/ctran/utils/Checks.h"
#include "comms/ctran/window/WinHintUtils.h"
#include "meta/NcclxConfig.h" // @manual
#include "meta/wrapper/MetaFactory.h"

#include <algorithm>

namespace ncclx {

using meta::comms::hints::AllToAllPHintUtils;
Expand All @@ -21,26 +24,51 @@ __attribute__((visibility("default"))) Hints::Hints() {
WinHintUtils::init(this->kv);
}

// Convenience constructor: builds a Hints object from a braced list of
// {key, value} string pairs. The default constructor runs first (via
// delegation), then every pair is forwarded to set() in list order. The
// result code returned by set() is not inspected here.
__attribute__((visibility("default"))) Hints::Hints(
    std::initializer_list<std::pair<std::string, std::string>> init)
    : Hints() {
  for (auto entry = init.begin(); entry != init.end(); ++entry) {
    set(entry->first, entry->second);
  }
}

// Normalizes a hint key by dropping a leading "ncclx::" qualifier. Keys that
// do not begin with that qualifier are returned unchanged, so callers may use
// "fastInitMode" and "ncclx::fastInitMode" interchangeably.
static std::string stripNcclxPrefix(const std::string& key) {
  constexpr std::string_view kPrefix = "ncclx::";
  const std::string_view keyView{key};
  // substr() clamps the length, so keys shorter than the prefix are safe.
  const bool hasPrefix = keyView.substr(0, kPrefix.size()) == kPrefix;
  return hasPrefix ? std::string(keyView.substr(kPrefix.size())) : key;
}

// Sets the hint `key` to `val`.
//
// The key is first passed through stripNcclxPrefix(), and the resulting
// `bareKey` drives the dispatch: keys starting with "ncclx_alltoallv_dynamic",
// "ncclx_alltoallp" or "window" are forwarded to the matching *HintUtils::set
// helper together with this->kv; each of those branches then returns
// ncclSuccess (NCCLCHECK presumably returns early on a failed helper call —
// TODO confirm against its definition).
//
// NOTE(review): this span is taken from a rendered diff, so pre-change lines
// (the ones that test/forward `key`, and `return ncclInvalidArgument;`) sit
// directly next to their post-change replacements (the ones using `bareKey`).
// The comments here describe the `bareKey` lines — confirm against the real
// file before relying on them.
__attribute__((visibility("default"))) ncclResult_t
Hints::set(const std::string& key, const std::string& val) {
if (key.starts_with("ncclx_alltoallv_dynamic")) {
auto bareKey = stripNcclxPrefix(key);
if (bareKey.starts_with("ncclx_alltoallv_dynamic")) {
NCCLCHECK(
metaCommToNccl(AllToAllvDynamicHintUtils::set(key, val, this->kv)));
metaCommToNccl(AllToAllvDynamicHintUtils::set(bareKey, val, this->kv)));
return ncclSuccess;
} else if (key.starts_with("ncclx_alltoallp")) {
NCCLCHECK(metaCommToNccl(AllToAllPHintUtils::set(key, val, this->kv)));
} else if (bareKey.starts_with("ncclx_alltoallp")) {
NCCLCHECK(metaCommToNccl(AllToAllPHintUtils::set(bareKey, val, this->kv)));
return ncclSuccess;
} else if (key.starts_with(("window"))) {
NCCLCHECK(metaCommToNccl(WinHintUtils::set(key, val, this->kv)));
} else if (bareKey.starts_with(("window"))) {
NCCLCHECK(metaCommToNccl(WinHintUtils::set(bareKey, val, this->kv)));
return ncclSuccess;
} else {
return ncclInvalidArgument;
// Any other key is stored directly in this->kv under its bare name. A key
// that is not listed in ncclx::knownHintKeys() only produces a warning; it is
// still stored and the call still reports ncclSuccess.
const auto& knownKeys = ncclx::knownHintKeys();
if (std::find(knownKeys.begin(), knownKeys.end(), bareKey) ==
knownKeys.end()) {
WARN("NCCLX Hints: unknown key '%s'; check spelling", bareKey.c_str());
}
this->kv[bareKey] = val;
return ncclSuccess;
}
}

__attribute__((visibility("default"))) ncclResult_t
Hints::get(const std::string& key, std::string& val) const {
auto iter = this->kv.find(key);
auto iter = this->kv.find(stripNcclxPrefix(key));
if (iter != this->kv.end()) {
val = iter->second;
return ncclSuccess;
Expand Down
209 changes: 209 additions & 0 deletions comms/ncclx/meta/hints/tests/ConfigHintsUT.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.

#include <gtest/gtest.h>
#include <string>
#include <vector>

#include "nccl.h" // @manual

#include "meta/NcclxConfig.h" // @manual

// ----- ncclxParseCommConfig tests -----

// Parsing a config that carries no hints must still attach an ncclx::Config
// populated with default values, and must leave upstream NCCL fields alone.
TEST(ConfigHintsUT, NoHintsCreatesDefaults) {
  // `hints` is left at its initializer value, (void*)NCCL_CONFIG_UNDEF_PTR.
  ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
  EXPECT_EQ(ncclxParseCommConfig(&config), ncclSuccess);

  // An ncclx::Config must have been allocated.
  ASSERT_NE(config.ncclxConfig, (void*)NCCL_CONFIG_UNDEF_PTR);
  ASSERT_NE(config.ncclxConfig, nullptr);

  auto* parsed = static_cast<ncclx::Config*>(config.ncclxConfig);
  const ncclx::Config& defaults = *parsed;
  EXPECT_EQ(defaults.commDesc, "undefined");
  EXPECT_TRUE(defaults.splitGroupRanks.empty());
  EXPECT_EQ(defaults.ncclAllGatherAlgo, "undefined");
  EXPECT_FALSE(defaults.lazyConnect);

  // Fields owned by upstream NCCL keep their "undefined" sentinel.
  EXPECT_EQ(config.blocking, NCCL_CONFIG_UNDEF_INT);
  EXPECT_EQ(config.cgaClusterSize, NCCL_CONFIG_UNDEF_INT);

  delete parsed;
}

// Hints given with bare (unprefixed) keys must show up in the parsed
// ncclx::Config, while upstream NCCL fields stay untouched.
TEST(ConfigHintsUT, HintsCreateNcclxConfig) {
  ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
  ncclx::Hints hints;

  // {key, value} pairs applied in this exact order.
  const char* const kEntries[][2] = {
      {"commDesc", "test_desc"},
      {"lazyConnect", "1"},
      {"lazySetupChannels", "0"},
      {"fastInitMode", "1"},
      {"ncclAllGatherAlgo", "custom_algo"},
  };
  for (const auto& entry : kEntries) {
    hints.set(entry[0], entry[1]);
  }
  config.hints = &hints;

  EXPECT_EQ(ncclxParseCommConfig(&config), ncclSuccess);

  ASSERT_NE(config.ncclxConfig, (void*)NCCL_CONFIG_UNDEF_PTR);
  ASSERT_NE(config.ncclxConfig, nullptr);

  EXPECT_EQ(NCCLX_CONFIG_FIELD(config, commDesc), "test_desc");
  EXPECT_TRUE(NCCLX_CONFIG_FIELD(config, lazyConnect));
  EXPECT_FALSE(NCCLX_CONFIG_FIELD(config, lazySetupChannels));
  EXPECT_TRUE(NCCLX_CONFIG_FIELD(config, fastInitMode));
  EXPECT_EQ(NCCLX_CONFIG_FIELD(config, ncclAllGatherAlgo), "custom_algo");

  // Upstream NCCL field keeps its "undefined" sentinel.
  EXPECT_EQ(config.blocking, NCCL_CONFIG_UNDEF_INT);

  auto* parsed = static_cast<ncclx::Config*>(config.ncclxConfig);
  delete parsed;
}

// Keys written with the "ncclx::" prefix must yield the same parsed config as
// the bare keys exercised by HintsCreateNcclxConfig, and get() must accept
// both spellings.
TEST(ConfigHintsUT, PrefixedKeysMatchBareKeys) {
  ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
  ncclx::Hints hints;

  // {key, value} pairs applied in this exact order.
  const char* const kPrefixedEntries[][2] = {
      {"ncclx::commDesc", "test_desc"},
      {"ncclx::lazyConnect", "1"},
      {"ncclx::lazySetupChannels", "0"},
      {"ncclx::fastInitMode", "1"},
      {"ncclx::ncclAllGatherAlgo", "custom_algo"},
  };
  for (const auto& entry : kPrefixedEntries) {
    hints.set(entry[0], entry[1]);
  }
  config.hints = &hints;

  EXPECT_EQ(ncclxParseCommConfig(&config), ncclSuccess);

  ASSERT_NE(config.ncclxConfig, (void*)NCCL_CONFIG_UNDEF_PTR);
  ASSERT_NE(config.ncclxConfig, nullptr);

  EXPECT_EQ(NCCLX_CONFIG_FIELD(config, commDesc), "test_desc");
  EXPECT_TRUE(NCCLX_CONFIG_FIELD(config, lazyConnect));
  EXPECT_FALSE(NCCLX_CONFIG_FIELD(config, lazySetupChannels));
  EXPECT_TRUE(NCCLX_CONFIG_FIELD(config, fastInitMode));
  EXPECT_EQ(NCCLX_CONFIG_FIELD(config, ncclAllGatherAlgo), "custom_algo");

  // Lookup through the prefixed key ...
  std::string fetched;
  EXPECT_EQ(hints.get("ncclx::commDesc", fetched), ncclSuccess);
  EXPECT_EQ(fetched, "test_desc");
  // ... and through the bare key, reusing the same output string.
  EXPECT_EQ(hints.get("commDesc", fetched), ncclSuccess);
  EXPECT_EQ(fetched, "test_desc");

  auto* parsed = static_cast<ncclx::Config*>(config.ncclxConfig);
  delete parsed;
}

// Checks that the boolean hint "lazyConnect" accepts the usual textual
// spellings of true and false.
//
// The parse result and the resulting ncclxConfig pointer are checked with
// fatal ASSERT_* macros: the following lines dereference and delete
// config.ncclxConfig, which would touch a sentinel/null pointer if parsing
// had failed and the test were allowed to continue.
TEST(ConfigHintsUT, BoolHintFormats) {
  // Spellings expected to parse as true.
  for (const char* trueVal :
       {"1", "yes", "YES", "Yes", "true", "TRUE", "True", "y", "Y", "t", "T"}) {
    ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
    ncclx::Hints hints;
    hints.set("lazyConnect", trueVal);
    config.hints = &hints;
    ASSERT_EQ(ncclxParseCommConfig(&config), ncclSuccess) << trueVal;
    ASSERT_NE(config.ncclxConfig, (void*)NCCL_CONFIG_UNDEF_PTR) << trueVal;
    ASSERT_NE(config.ncclxConfig, nullptr) << trueVal;
    EXPECT_TRUE(NCCLX_CONFIG_FIELD(config, lazyConnect)) << trueVal;
    delete static_cast<ncclx::Config*>(config.ncclxConfig);
  }
  // Spellings expected to parse as false.
  for (const char* falseVal :
       {"0", "no", "NO", "No", "false", "FALSE", "False", "n", "N", "f", "F"}) {
    ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
    ncclx::Hints hints;
    hints.set("lazyConnect", falseVal);
    config.hints = &hints;
    ASSERT_EQ(ncclxParseCommConfig(&config), ncclSuccess) << falseVal;
    ASSERT_NE(config.ncclxConfig, (void*)NCCL_CONFIG_UNDEF_PTR) << falseVal;
    ASSERT_NE(config.ncclxConfig, nullptr) << falseVal;
    EXPECT_FALSE(NCCLX_CONFIG_FIELD(config, lazyConnect)) << falseVal;
    delete static_cast<ncclx::Config*>(config.ncclxConfig);
  }
}

// Values written straight into the flat ncclConfig_t members (the old format)
// must be carried over into the parsed ncclx::Config.
TEST(ConfigHintsUT, OldFormatFlatFields) {
  ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
  // Old format: populate the members of ncclConfig_t directly, no hints.
  config.commDesc = "old_desc";
  config.lazyConnect = 1;
  config.fastInitMode = 2;

  EXPECT_EQ(ncclxParseCommConfig(&config), ncclSuccess);

  ASSERT_NE(config.ncclxConfig, (void*)NCCL_CONFIG_UNDEF_PTR);
  ASSERT_NE(config.ncclxConfig, nullptr);

  EXPECT_EQ(NCCLX_CONFIG_FIELD(config, commDesc), "old_desc");
  EXPECT_TRUE(NCCLX_CONFIG_FIELD(config, lazyConnect));
  EXPECT_TRUE(NCCLX_CONFIG_FIELD(config, fastInitMode));

  auto* parsed = static_cast<ncclx::Config*>(config.ncclxConfig);
  delete parsed;
}

// Supplying the same setting through both the flat field and a hint is
// expected to be rejected, without an ncclx::Config being attached.
TEST(ConfigHintsUT, ConflictReturnsError) {
  ncclConfig_t config = NCCL_CONFIG_INITIALIZER;

  // New format: lazyConnect provided as a hint ...
  ncclx::Hints conflictingHints;
  conflictingHints.set("lazyConnect", "0");
  config.hints = &conflictingHints;
  // ... and old format: the same setting on the flat field.
  config.lazyConnect = 1;

  const auto parseResult = ncclxParseCommConfig(&config);
  EXPECT_EQ(parseResult, ncclInvalidArgument);

  // The ncclxConfig slot must still hold its "undefined" sentinel.
  EXPECT_EQ(config.ncclxConfig, (void*)NCCL_CONFIG_UNDEF_PTR);
}

// ncclxParseCommConfig is expected to succeed once per config; a second call
// on the same config must report ncclInvalidArgument.
TEST(ConfigHintsUT, DoubleParseReturnsError) {
  ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
  ncclx::Hints hints;
  hints.set("commDesc", "first_call");
  config.hints = &hints;

  // First parse succeeds and attaches the ncclx::Config.
  const auto firstResult = ncclxParseCommConfig(&config);
  EXPECT_EQ(firstResult, ncclSuccess);

  ASSERT_NE(config.ncclxConfig, (void*)NCCL_CONFIG_UNDEF_PTR);
  ASSERT_NE(config.ncclxConfig, nullptr);
  EXPECT_EQ(NCCLX_CONFIG_FIELD(config, commDesc), "first_call");

  // Parsing the already-parsed config again must be rejected.
  const auto secondResult = ncclxParseCommConfig(&config);
  EXPECT_EQ(secondResult, ncclInvalidArgument);

  auto* parsed = static_cast<ncclx::Config*>(config.ncclxConfig);
  delete parsed;
}

// ----- splitGroupRanks tests -----

// A comma-separated "splitGroupRanks" hint must be parsed into the matching
// vector of ints on the ncclx::Config.
TEST(ConfigHintsUT, SplitGroupRanksSetViaHints) {
  ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
  ncclx::Hints rankHints;
  rankHints.set("splitGroupRanks", "0,1,2,3");
  config.hints = &rankHints;

  EXPECT_EQ(ncclxParseCommConfig(&config), ncclSuccess);

  ASSERT_NE(config.ncclxConfig, (void*)NCCL_CONFIG_UNDEF_PTR);
  ASSERT_NE(config.ncclxConfig, nullptr);

  auto* parsed = static_cast<ncclx::Config*>(config.ncclxConfig);
  const std::vector<int> wantRanks = {0, 1, 2, 3};
  EXPECT_EQ(parsed->splitGroupRanks, wantRanks);

  delete parsed;
}

// A "splitGroupRanks" hint holding a single number (no commas) must be parsed
// into a one-element vector.
TEST(ConfigHintsUT, SplitGroupRanksSingleRank) {
  ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
  ncclx::Hints rankHints;
  rankHints.set("splitGroupRanks", "7");
  config.hints = &rankHints;

  EXPECT_EQ(ncclxParseCommConfig(&config), ncclSuccess);

  ASSERT_NE(config.ncclxConfig, (void*)NCCL_CONFIG_UNDEF_PTR);
  ASSERT_NE(config.ncclxConfig, nullptr);

  auto* parsed = static_cast<ncclx::Config*>(config.ncclxConfig);
  const std::vector<int> wantRanks = {7};
  EXPECT_EQ(parsed->splitGroupRanks, wantRanks);

  delete parsed;
}
3 changes: 2 additions & 1 deletion comms/ncclx/v2_27/examples/HelloWorld.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ int main(int argc, char* argv[]) {
cudaStream_t stream;
int* userBuff = NULL;
ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
config.commDesc = "example_pg";
ncclx::Hints hints({{"commDesc", "example_pg"}});
config.hints = &hints;

ncclUniqueId ncclId;
NCCLCHECK(ncclGetUniqueId(&ncclId));
Expand Down
Loading
Loading