Skip to content

UCP/CORE: disable everything (aux as well) when ^ib is configured #10579

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 24 additions & 11 deletions src/ucp/core/ucp_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,10 @@
#define UCP_CONTEXT_INFINITE_LAT_FACTOR 100

typedef enum ucp_transports_list_search_result {
UCP_TRANSPORTS_LIST_SEARCH_RESULT_PRIMARY = UCS_BIT(0),
UCP_TRANSPORTS_LIST_SEARCH_RESULT_AUX_IN_MAIN = UCS_BIT(1),
UCP_TRANSPORTS_LIST_SEARCH_RESULT_AUX_IN_ALIAS = UCS_BIT(2)
UCP_TRANSPORTS_LIST_SEARCH_RESULT_PRIMARY = UCS_BIT(0),
UCP_TRANSPORTS_LIST_SEARCH_RESULT_AUX_IN_MAIN = UCS_BIT(1),
UCP_TRANSPORTS_LIST_SEARCH_RESULT_AUX_IN_ALIAS = UCS_BIT(2),
UCP_TRANSPORTS_LIST_SEARCH_RESULT_TL_AND_AUX_IN_ALIAS = UCS_BIT(3)
} ucp_transports_list_search_result_t;


Expand Down Expand Up @@ -689,7 +690,7 @@ static ucp_tl_alias_t ucp_tl_aliases[] = {
{ "sm", { "posix", "sysv", "xpmem", "knem", "cma", NULL } },
{ "shm", { "posix", "sysv", "xpmem", "knem", "cma", NULL } },
{ "ib", { "rc_verbs", "ud_verbs", "rc_mlx5", "ud_mlx5", "dc_mlx5",
"gga_mlx5", NULL } },
"gga_mlx5", UCP_TL_AUX("ud_mlx5"), UCP_TL_AUX("ud_verbs"), NULL } },
{ "ud_v", { "ud_verbs", NULL } },
{ "ud_x", { "ud_mlx5", NULL } },
{ "ud", { "ud_mlx5", "ud_verbs", NULL } },
Expand Down Expand Up @@ -1022,6 +1023,7 @@ ucp_transports_list_search(const char *tl_name,
uint8_t search_result = 0;
uint64_t tmp_tl_cfg_mask;
ucp_tl_alias_t *alias;
int tl_in_alias, tl_aux_in_alias;

if (ucp_config_is_tl_name_present(tl_array, tl_name, 0, NULL,
tl_cfg_mask)) {
Expand All @@ -1039,17 +1041,26 @@ ucp_transports_list_search(const char *tl_name,
tmp_tl_cfg_mask = 0;
if (ucp_config_is_tl_name_present(tl_array, alias->alias, 1, NULL,
&tmp_tl_cfg_mask)) {
if (ucp_tls_alias_is_present(alias, tl_name, NULL)) {
tl_in_alias = ucp_tls_alias_is_present(alias, tl_name, NULL);
if (tl_in_alias) {
/* alias={tl_name}, UCX_TLS=[^]alias */
*tl_cfg_mask |= tmp_tl_cfg_mask;
search_result |= UCP_TRANSPORTS_LIST_SEARCH_RESULT_PRIMARY;
}

if (ucp_tls_alias_is_present(alias, tl_name, UCP_TL_AUX_SUFFIX)) {
tl_aux_in_alias = ucp_tls_alias_is_present(alias, tl_name,
UCP_TL_AUX_SUFFIX);
if (tl_aux_in_alias) {
/* alias={tl_name:aux}, UCX_TLS=[^]alias */
*tl_cfg_mask |= tmp_tl_cfg_mask;
search_result |= UCP_TRANSPORTS_LIST_SEARCH_RESULT_AUX_IN_ALIAS;
}

if (tl_in_alias && tl_aux_in_alias) {
/* alias={tl_name, tl_name:aux}, UCX_TLS=[^]alias */
search_result |=
UCP_TRANSPORTS_LIST_SEARCH_RESULT_TL_AND_AUX_IN_ALIAS;
}
}

tmp_tl_cfg_mask = 0;
Expand Down Expand Up @@ -1108,13 +1119,15 @@ ucp_is_resource_in_transports_list(const char *tl_name,
return !(search_result & UCP_TRANSPORTS_LIST_SEARCH_RESULT_PRIMARY);
}

/* Only explicit indication in the deny list can completely disable
* transport which can be used as an auxiliary.
* E.g: UCX_TLS=^tl_name,tl_name:aux, or alias={tl_name} and
* UCX_TLS=^alias,alias:aux. */
/* A transport that can be used as an auxiliary is disabled by
* including it in the deny list in one of the following ways:
* - UCX_TLS=^tl_name,tl_name:aux
* - UCX_TLS=^alias,alias:aux where alias={tl_name}
* - UCX_TLS=^alias where alias={tl_name,tl_name:aux} */
if (ucs_test_all_flags(search_result,
UCP_TRANSPORTS_LIST_SEARCH_RESULT_PRIMARY |
UCP_TRANSPORTS_LIST_SEARCH_RESULT_AUX_IN_MAIN)) {
UCP_TRANSPORTS_LIST_SEARCH_RESULT_AUX_IN_MAIN) ||
search_result & UCP_TRANSPORTS_LIST_SEARCH_RESULT_TL_AND_AUX_IN_ALIAS) {
return 0;
}

Expand Down
2 changes: 1 addition & 1 deletion src/ucp/core/ucp_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ typedef struct ucp_tl_resource_desc {
*/
typedef struct ucp_tl_alias {
const char *alias; /* Alias name */
const char* tls[8]; /* Transports which are selected by the alias */
const char* tls[10]; /* Transports which are selected by the alias */
} ucp_tl_alias_t;


Expand Down
91 changes: 68 additions & 23 deletions test/apps/test_ucx_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@
"mm": ["posix", "sysv", "xpmem", ],
"sm": ["posix", "sysv", "xpmem", "knem", "cma", "rdmacm", "sockcm", ],
"shm": ["posix", "sysv", "xpmem", "knem", "cma", "rdmacm", "sockcm", ],
"ib": ["rc_verbs", "ud_verbs", "rc_mlx5", "ud_mlx5", "dc_mlx5", "rdmacm", ],
"ib": ["rc_verbs", "ud_verbs", "rc_mlx5", "ud_mlx5", "dc_mlx5", "rdmacm",
"ud_mlx5:aux", "ud_verbs:aux", ],
"ud_v": ["ud_verbs", "rdmacm", ],
"ud_x": ["ud_mlx5", "rdmacm", ],
"ud": ["ud_mlx5", "ud_verbs", "rdmacm", ],
Expand All @@ -100,17 +101,20 @@
}

@contextlib.contextmanager
def _override_env(var_name, value):
if value is None:
yield
return
def _override_env(env_vars):
prev_values = []
for var_name, value in env_vars:
prev_values.append((var_name, os.getenv(var_name)))
if value is not None:
os.putenv(var_name, value)
else:
os.unsetenv(var_name)

prev_value = os.getenv(var_name)
os.putenv(var_name, value)
try:
yield
finally:
os.putenv(var_name, prev_value) if prev_value else os.unsetenv(var_name)
for var_name, prev_value in prev_values:
os.putenv(var_name, prev_value) if prev_value else os.unsetenv(var_name)

def exec_cmd(cmd):
if options.verbose:
Expand All @@ -123,25 +127,40 @@ def exec_cmd(cmd):

return status, output

def find_am_transport(dev, neps=1, override=0, tls="ib"):
def find_transport(dev=None, neps=1, override=0, tls="ib", protocol="am"):
if (override):
os.putenv("UCX_NUM_EPS", "2")

with _override_env("UCX_TLS", tls), \
_override_env("UCX_NET_DEVICES", dev):

status, output = exec_cmd(f"{ucx_info}{ucx_info_args}{neps} | grep am")
env_vars = [("UCX_TLS", tls)]

# Set up environment variables based on protocol type
if protocol == "am" and dev:
env_vars.append(("UCX_NET_DEVICES", dev))

# Use context manager for all environment variables
with _override_env(env_vars):
# Choose the appropriate arguments and grep pattern based on protocol type
if protocol == "keepalive":
args = ucx_info_eh_args
elif protocol == "am": # am transport
args = ucx_info_args

status, output = exec_cmd(f"{ucx_info}{args}{neps} | grep {protocol}")

match = re.search(r'\d+:(\S+)/\S+', output)
if match:
am_tls = match.group(1)
if (override):
proto_tls = match.group(1)
if override:
os.unsetenv("UCX_NUM_EPS")

return am_tls
return proto_tls
else:
return None

def find_am_transport(dev, neps=1, override=0, tls="ib"):
return find_transport(dev=dev, neps=neps, override=override,
tls=tls, protocol="am")

def test_fallback_from_rc(dev, neps) :

os.putenv("UCX_TLS", "ib")
Expand Down Expand Up @@ -185,17 +204,29 @@ def test_ucx_tls_positive(tls):
print("Found TL doesn't belong to the allowed UCX_TLS")
sys.exit(1)

def test_ucx_tls_negative(tls):
# Use TLS list in "negate" mode and verify that the found tl is not in the list
found_tl = find_am_transport(None, tls="^"+tls)
print(f"Using UCX_TLS={tls}, found TL: {found_tl}")
tls = tls.split(',')
if not found_tl or found_tl in tls:
def test_ucx_tls_negative(tls, protocol="am", forbidden_tls=None):
# Use TLS list in "negate" mode
found_tl = find_transport(tls="^"+tls, protocol=protocol)
print(f"Using UCX_TLS=^{tls}, found {protocol} TL: {found_tl}")
if not found_tl:
print("No available TL found")
sys.exit(1)

# If forbidden_tls is provided, verify that the found tl is not in that list
if forbidden_tls is not None:
if found_tl in forbidden_tls:
print(f"Found forbidden TL: {found_tl}")
sys.exit(1)
return

# Otherwise, check against the tls list
tls = tls.split(',')
if found_tl in tls:
print(f"Found forbidden TL: {found_tl}")
sys.exit(1)
for tl in tls:
if tl in tl_aliases and found_tl in tl_aliases[tl]:
print("Found TL belongs to the forbidden UCX_TLS")
print(f"Found forbidden TL: {found_tl}")
sys.exit(1)

def _powerset(iterable, with_empty_set=True):
Expand Down Expand Up @@ -224,6 +255,19 @@ def test_tls_allow_list(ucx_info):
itertools.product(tls_variants, test_funcs):
test_func(",".join(tls_variant))

# Test auxiliary transport negation
test_cases_negative = [
("ib", {"ud_mlx5", "ud_verbs"}),
("ud,ud:aux", {"ud_mlx5", "ud_verbs"}),
("ud_v,ud_v:aux", {"ud_verbs"}),
("ud_x,ud_x:aux", {"ud_mlx5"}),
("ud_verbs,ud_verbs:aux", {"ud_verbs"}),
("ud_mlx5,ud_mlx5:aux", {"ud_mlx5"})
]

for tls, forbidden_tls in test_cases_negative:
test_ucx_tls_negative(tls, protocol="keepalive", forbidden_tls=forbidden_tls)

parser = OptionParser()
parser.add_option("-p", "--prefix", metavar="PATH", help = "root UCX directory")
parser.add_option("-v", "--verbose", action="store_true", \
Expand All @@ -242,6 +286,7 @@ def test_tls_allow_list(ucx_info):

ucx_info = bin_prefix + "/ucx_info"
ucx_info_args = " -e -u t -n "
ucx_info_eh_args = " -e -u et -n "

status, output = exec_cmd(ucx_info + " -c | grep -e \"UCX_RC_.*_MAX_NUM_EPS\"")
match = re.findall(r'\S+=(\d+)', output)
Expand Down
Loading