Skip to content

Commit

Permalink
Default target to CUDA when installed (#88)
Browse files Browse the repository at this point in the history
Co-authored-by: José Valim <[email protected]>
  • Loading branch information
jonatanklosko and josevalim committed Jul 1, 2024
1 parent 028e839 commit e2da131
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ need to export it in every shell session.

#### `XLA_TARGET`

The default value is `cpu`, which implies the final the binary supports targeting
only the host CPU.
The default value is usually `cpu`, which implies the final the binary supports targeting
only the host CPU. If a matching CUDA version is detected, the target is set to CUDA accordingly.

| Value | Target environment |
| --- | --- |
Expand Down
11 changes: 10 additions & 1 deletion lib/xla.ex
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ defmodule XLA do
end

defp xla_target() do
target = System.get_env("XLA_TARGET", "cpu")
target = System.get_env("XLA_TARGET") || infer_xla_target() || "cpu"

supported_xla_targets = ["cpu", "cuda", "rocm", "tpu", "cuda12"]

Expand All @@ -56,6 +56,15 @@ defmodule XLA do
target
end

defp infer_xla_target() do
with nvcc when nvcc != nil <- System.find_executable("nvcc"),
{output, 0} <- System.cmd(nvcc, ["--version"]) do
if output =~ "release 12.", do: "cuda12"
else
_ -> nil
end
end

defp xla_cache_dir() do
# The directory where we store all the archives
if dir = System.get_env("XLA_CACHE_DIR") do
Expand Down

0 comments on commit e2da131

Please sign in to comment.