From e2da1319881d60a7fa7bad989ebd7b93bed8d554 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Mon, 1 Jul 2024 11:42:29 +0200 Subject: [PATCH] Default target to CUDA when installed (#88) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: José Valim --- README.md | 4 ++-- lib/xla.ex | 11 ++++++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index bac075a..b943681 100644 --- a/README.md +++ b/README.md @@ -17,8 +17,8 @@ need to export it in every shell session. #### `XLA_TARGET` -The default value is `cpu`, which implies the final the binary supports targeting -only the host CPU. +The default value is usually `cpu`, which implies the final the binary supports targeting +only the host CPU. If a matching CUDA version is detected, the target is set to CUDA accordingly. | Value | Target environment | | --- | --- | diff --git a/lib/xla.ex b/lib/xla.ex index b02fee0..50c08ca 100644 --- a/lib/xla.ex +++ b/lib/xla.ex @@ -44,7 +44,7 @@ defmodule XLA do end defp xla_target() do - target = System.get_env("XLA_TARGET", "cpu") + target = System.get_env("XLA_TARGET") || infer_xla_target() || "cpu" supported_xla_targets = ["cpu", "cuda", "rocm", "tpu", "cuda12"] @@ -56,6 +56,15 @@ defmodule XLA do target end + defp infer_xla_target() do + with nvcc when nvcc != nil <- System.find_executable("nvcc"), + {output, 0} <- System.cmd(nvcc, ["--version"]) do + if output =~ "release 12.", do: "cuda12" + else + _ -> nil + end + end + defp xla_cache_dir() do # The directory where we store all the archives if dir = System.get_env("XLA_CACHE_DIR") do