Description
Because TensorFlow.jl's `with_device()` expects 1-based device numbering, we cannot feed the output of `DeviceList()` directly into `with_device()`, as `DeviceList()` gives zero-indexed device names.
My current workaround is to do something like the following:

    function get_device(sess, device_type)
        # Find first device with the given device type (e.g. `"XLA_GPU"`)
        devices = collect(TensorFlow.DeviceList(sess))
        device = first(filter(x -> x.device_type == device_type, devices))

        # Fixup this device name so that it is 1-indexed, as TensorFlow.jl requires
        inc_number(x::Number) = x + 1
        inc_number(x) = x
        function fixup_device!(d::TensorFlow.Device)
            d.parts[:] .= [TensorFlow.DevicePart(p.kind, inc_number(p.index)) for p in d.parts]
            return d
        end
        fixup_device!(d) = d

        return fixup_device!(TensorFlow.Device(device.name))
    end

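
For illustration, the same shift can be sketched on plain strings, without touching any TensorFlow.jl types. This is only a rough sketch: `to_one_based` is a hypothetical helper, not part of TensorFlow.jl, and it assumes device names of the form `/job:<name>/replica:<n>/task:<n>/device:<TYPE>:<n>`, where every purely numeric component should be incremented (which is what the workaround above does).

```julia
# Hypothetical helper (not part of TensorFlow.jl): shift every purely numeric
# component of a zero-indexed device name up by one.
function to_one_based(name::AbstractString)
    # Match a run of digits that directly follows ':' and is followed by '/'
    # or the end of the string, so names such as "worker1" are left untouched.
    return replace(name, r"(?<=:)\d+(?=/|$)" => s -> string(parse(Int, s) + 1))
end

zero_based = "/job:job/replica:0/task:0/device:XLA_GPU:0"
one_based = to_one_based(zero_based)
println(one_based)
```

The resulting string would then presumably be what `with_device` expects today, but I have not checked that against every device name format TensorFlow can emit.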
Personally, I would prefer that `with_device` used zero-indexed device names. On systems with multiple devices (e.g. multiple GPUs), it is an unnecessary mental burden to always remember that `/job:job/replica:1/task:1/device:CPU:1` in TensorFlow.jl does not name the same device as `/job:job/replica:1/task:1/device:CPU:1` anywhere else in the TensorFlow ecosystem. Regardless, we should be consistent, so that the output of one function can be fed to another within TensorFlow.jl.