From b923047feadc0dd74480766fa7cfa0d58f959084 Mon Sep 17 00:00:00 2001
From: Johannes Blaschke
Date: Thu, 10 Apr 2025 11:18:39 -0700
Subject: [PATCH 1/2] wip on streamlining slurm integration

---
 src/Distributed.jl | 36 ++++++++++++++++--------------------
 src/Reactant.jl    |  2 ++
 2 files changed, 18 insertions(+), 20 deletions(-)

diff --git a/src/Distributed.jl b/src/Distributed.jl
index 776c2e7ac3..25f19a3f6a 100644
--- a/src/Distributed.jl
+++ b/src/Distributed.jl
@@ -1,6 +1,6 @@
 module Distributed
 
-using ..Reactant: Reactant
+using ..Reactant: Reactant, Hostlists
 using Sockets
 
 const initialized = Ref(false)
@@ -266,27 +266,23 @@ const _SLURM_NUM_NODES = "SLURM_STEP_NUM_NODES"
 is_env_present(::SlurmEnvDetector) = haskey(ENV, _SLURM_JOB_ID)
 
 function get_coordinator_address(::SlurmEnvDetector, ::Integer)
-    port = parse(Int, ENV[_SLURM_JOB_ID]) % 2^12 + (65535 - 2^12 + 1)
-
-    # Parse the first hostname of the job
-    # If we are looking for 'node001',
-    # node_list potential formats are 'node001', 'node001,host2',
-    # 'node[001-0015],host2', and 'node[001,007-015],host2'.
-    node_list = ENV[_SLURM_NODELIST]
-    ind = findfirst(Base.Fix2(in, (',', '[')), node_list)
-    ind = isnothing(ind) ? length(node_list) + 1 : ind
-
-    if ind == length(node_list) + 1 || node_list[ind] == ','
-        # 'node001' or 'node001,host2'
-        return "$(node_list[1:ind-1]):$(port)"
+    if haskey(ENV, "REACTANT_COORDINATOR_BIND_ADDRESS")
+        # Reuse the port of the explicitly configured bind address.
+        bind_address = ENV["REACTANT_COORDINATOR_BIND_ADDRESS"]
+        port = parse(Int, last(split(bind_address, ":")))
+        @debug "Port: $(port) inferred from REACTANT_COORDINATOR_BIND_ADDRESS"
     else
-        # 'node[001-0015],host2' or 'node[001,007-015],host2'
-        prefix = node_list[1:(ind - 1)]
-        suffix = node_list[(ind + 1):end]
-        ind2 = findfirst(Base.Fix2(in, (',', '-')), suffix)
-        ind2 = isnothing(ind2) ? length(suffix) : ind2
-        return "$(prefix)$(suffix[1:ind2-1]):$(port)"
+        # Derive a port in 61440:65535 from the job id.
+        port = parse(Int, ENV[_SLURM_JOB_ID]) % 2^12 + (65535 - 2^12 + 1)
+        @debug "Port: $(port) inferred from _SLURM_JOB_ID"
     end
+
+    # The coordinator is the first host of the job's node list.
+    node_list = ENV[_SLURM_NODELIST]
+    @debug "Setup coordinator: node_list=$(node_list)"
+    broker_addr = first(Hostlists.Hostlist(node_list))
+    @debug "Setup coordinator: broker_addr=$(broker_addr)"
+    return "$(broker_addr):$(port)"
 end
 
 get_process_count(::SlurmEnvDetector) = parse(Int, ENV[_SLURM_PROCESS_COUNT])
diff --git a/src/Reactant.jl b/src/Reactant.jl
index 5199db4453..2de8a6fd6f 100644
--- a/src/Reactant.jl
+++ b/src/Reactant.jl
@@ -58,6 +58,8 @@ include("Devices.jl")
 include("Interpreter.jl")
 include("Profiler.jl")
 include("Types.jl")
+
+include(joinpath("extern", "hostlists.jl"))
 include("Distributed.jl")
 
 const with_profiler = Profiler.with_profiler

From 930ae84c6c49095bae2a845c0cb7b8cefe645830 Mon Sep 17 00:00:00 2001
From: Johannes Blaschke
Date: Thu, 10 Apr 2025 11:21:23 -0700
Subject: [PATCH 2/2] add slurm bindings

---
 src/extern/hostlists.jl | 194 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 194 insertions(+)
 create mode 100644 src/extern/hostlists.jl

diff --git a/src/extern/hostlists.jl b/src/extern/hostlists.jl
new file mode 100644
index 0000000000..608f624de5
--- /dev/null
+++ b/src/extern/hostlists.jl
@@ -0,0 +1,194 @@
+module Hostlists
+
+# Bindings to the hostlist API of libslurm.
+module SlurmHostlists
+
+import Libdl
+
+# Empty string when no libslurm could be located.
+const libslurm = Libdl.find_library(["libslurm"])
+
+function __init__()
+    # We need to dlopen libslurm with RTLD_GLOBAL to make sure that all
+    # dependencies are loaded correctly. This is done at load time because
+    # top-level statements are not re-run when a precompiled module is loaded.
+    isempty(libslurm) || Libdl.dlopen(libslurm, Libdl.RTLD_GLOBAL; throw_error=false)
+    return nothing
+end
+
+const hostlist_t = Ptr{Nothing}
+
+slurm_hostlist_create(hostlist) =
+    @ccall libslurm.slurm_hostlist_create(hostlist::Cstring)::hostlist_t
+slurm_hostlist_count(hl::hostlist_t) =
+    @ccall libslurm.slurm_hostlist_count(hl::hostlist_t)::Cint
+slurm_hostlist_destroy(hl::hostlist_t) =
+    @ccall libslurm.slurm_hostlist_destroy(hl::hostlist_t)::Cvoid
+slurm_hostlist_find(hl::hostlist_t, hostname) =
+    @ccall libslurm.slurm_hostlist_find(hl::hostlist_t, hostname::Cstring)::Cint
+slurm_hostlist_push(hl::hostlist_t, hosts) =
+    @ccall libslurm.slurm_hostlist_push(hl::hostlist_t, hosts::Cstring)::Cint
+slurm_hostlist_push_host(hl::hostlist_t, host) =
+    @ccall libslurm.slurm_hostlist_push_host(hl::hostlist_t, host::Cstring)::Cint
+slurm_hostlist_ranged_string(hl::hostlist_t, n::Csize_t, buf) =
+    @ccall libslurm.slurm_hostlist_ranged_string(
+        hl::hostlist_t, n::Csize_t, buf::Ptr{UInt8}
+    )::Cssize_t
+# Returned as a raw pointer (not Cstring) because the caller must free it.
+slurm_hostlist_shift(hl::hostlist_t) =
+    @ccall libslurm.slurm_hostlist_shift(hl::hostlist_t)::Ptr{UInt8}
+slurm_hostlist_uniq(hl::hostlist_t) =
+    @ccall libslurm.slurm_hostlist_uniq(hl::hostlist_t)::Cvoid
+
+mutable struct Hostlist
+    hlist::hostlist_t
+
+    function Hostlist(node_list::String)
+        slurm_hl = slurm_hostlist_create(node_list)
+        if slurm_hl == C_NULL
+            error("Could not allocate memory for hostlist.")
+        end
+        hl = new(slurm_hl)
+        finalizer(delete, hl)
+        return hl
+    end
+end
+
+export Hostlist
+
+# Release the underlying slurm hostlist; safe to call more than once.
+function delete(hl::Hostlist)
+    if hl.hlist != C_NULL
+        slurm_hostlist_destroy(hl.hlist)
+        hl.hlist = C_NULL
+    end
+    return nothing
+end
+
+# NOTE: iteration is destructive, every step removes the returned host.
+function Base.iterate(hl::Hostlist, state::Union{Nothing,Hostlist}=nothing)
+    hn_ptr = slurm_hostlist_shift(hl.hlist)
+    (hn_ptr == C_NULL) && return nothing
+    hostname = unsafe_string(hn_ptr)
+    Libc.free(hn_ptr)  # slurm_hostlist_shift hands ownership to the caller
+    return hostname, hl
+end
+
+Base.eltype(::Type{Hostlist}) = String
+Base.IteratorEltype(::Type{Hostlist}) = Base.HasEltype()
+Base.IteratorSize(::Type{Hostlist}) = Base.SizeUnknown()
+
+function Base.convert(::Type{String}, hl::Hostlist, init_maxlen=8192)
+    maxlen = max(init_maxlen, 1)
+    while true
+        buf = Vector{UInt8}(undef, maxlen)
+        write_len = slurm_hostlist_ranged_string(hl.hlist, Csize_t(maxlen), buf)
+        # -1 signals truncation: retry with a buffer twice as large.
+        write_len == -1 || return String(buf[1:write_len])
+        maxlen *= 2
+    end
+end
+
+Base.string(hl::Hostlist) = Base.convert(String, hl)
+function Base.push!(x::Hostlist, y::String)
+    slurm_hostlist_push(x.hlist, y)
+    return x
+end
+Base.length(x::Hostlist) = slurm_hostlist_count(x.hlist)
+
+Base.show(io::IO, x::Hostlist) = print(io, string(x))
+
+end
+
+# Pure-Julia fallback used when libslurm is not available.
+module SimpleHostlists
+
+mutable struct Hostlist
+    hlist::Vector{String}
+
+    Hostlist(node_list::String) = new(expand_hostlist(node_list))
+end
+
+# Split `str` on the commas that are not enclosed in square brackets.
+function split_toplevel(str::AbstractString)
+    parts = String[]
+    depth = 0
+    start = firstindex(str)
+    for (i, c) in pairs(str)
+        if c == '['
+            depth += 1
+        elseif c == ']'
+            depth -= 1
+        elseif c == ',' && depth == 0
+            push!(parts, str[start:prevind(str, i)])
+            start = nextind(str, i)
+        end
+    end
+    push!(parts, str[start:end])
+    return filter!(!isempty, parts)
+end
+
+# Expand one entry such as "node[001-003,7]-ib" into its individual host names.
+function expand_item(item::AbstractString)
+    lb = findfirst('[', item)
+    isnothing(lb) && return [String(item)]
+    rb = findnext(']', item, lb)
+    isnothing(rb) && throw(ArgumentError("unbalanced '[' in hostlist: $(item)"))
+    prefix = item[1:prevind(item, lb)]
+    tails = expand_item(item[nextind(item, rb):end])
+    hosts = String[]
+    for span in split(item[nextind(item, lb):prevind(item, rb)], ',')
+        isempty(span) && continue
+        bounds = split(span, '-')
+        lo, hi = first(bounds), last(bounds)
+        for k in parse(Int, lo):parse(Int, hi)
+            # Zero-pad to the width of the lower bound so "001" stays 3 digits.
+            number = lpad(k, length(lo), '0')
+            append!(hosts, (string(prefix, number, tail) for tail in tails))
+        end
+    end
+    return hosts
+end
+
+# Expand a Slurm-style node list into the unique host names it denotes.
+function expand_hostlist(node_list::AbstractString)
+    hosts = String[]
+    for item in split_toplevel(node_list)
+        append!(hosts, expand_item(item))
+    end
+    return unique!(hosts)
+end
+
+Base.eltype(::Type{Hostlist}) = String
+Base.IteratorEltype(::Type{Hostlist}) = Base.HasEltype()
+Base.IteratorSize(::Type{Hostlist}) = Base.SizeUnknown()
+
+Base.iterate(hl::Hostlist) = Base.iterate(hl.hlist)
+Base.iterate(hl::Hostlist, state) = Base.iterate(hl.hlist, state)
+
+Base.convert(::Type{String}, hl::Hostlist) = join(hl.hlist, ",")
+Base.string(hl::Hostlist) = Base.convert(String, hl)
+function Base.push!(x::Hostlist, y::String)
+    append!(x.hlist, expand_hostlist(y))
+    unique!(x.hlist)
+    return x
+end
+Base.length(x::Hostlist) = length(x.hlist)
+
+Base.show(io::IO, x::Hostlist) = print(io, string(x))
+
+end
+
+# Build a host list with libslurm when it is available and with the
+# pure-Julia fallback otherwise.
+function Hostlist(node_list::AbstractString)
+    if isempty(SlurmHostlists.libslurm)
+        @debug "libslurm.so not found, using SimpleHostlists"
+        return SimpleHostlists.Hostlist(String(node_list))
+    else
+        @debug "libslurm.so found, using SlurmHostlists"
+        return SlurmHostlists.Hostlist(String(node_list))
+    end
+end
+
+end