
wip: use torch from a wheel #9340

Draft: wants to merge 1 commit into master
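
This draft replaces the torch setup based on new_local_repository plus a hand-maintained torch.BUILD with a torch_repo repository rule that unpacks a prebuilt torch wheel: the wheel comes from the TORCH_WHL environment variable (a local path or a URL), or, failing that, from the first *.whl found in the configured dist_dir. A sketch of the resulting WORKSPACE wiring (mirroring the diff below; ../dist assumes a PyTorch checkout beside this repository that has built a wheel):

load("//bazel:torch_repo.bzl", "torch_repo")

torch_repo(
    name = "torch",
    # The TORCH_WHL environment variable, when set, takes precedence over
    # scanning this directory for a *.whl file.
    dist_dir = "../dist",
)
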
17 changes: 6 additions & 11 deletions WORKSPACE
@@ -35,18 +35,18 @@ python_configure(
 ################################ PyTorch Setup ################################

 load("//bazel:dependencies.bzl", "PYTORCH_LOCAL_DIR")
+load("//bazel:torch_repo.bzl", "torch_repo")

-new_local_repository(
+torch_repo(
     name = "torch",
-    build_file = "//bazel:torch.BUILD",
-    path = PYTORCH_LOCAL_DIR,
+    dist_dir = "../dist",
 )

 ############################# OpenXLA Setup ###############################

 # To build PyTorch/XLA with a new revison of OpenXLA, update the xla_hash to
 # the openxla git commit hash and note the date of the commit.
-xla_hash = '9ac36592456e7be0d66506be75fbdacc90dd4e91' # Committed on 2025-06-11.
+xla_hash = "9ac36592456e7be0d66506be75fbdacc90dd4e91" # Committed on 2025-06-11.

 http_archive(
     name = "xla",
@@ -66,8 +66,6 @@
     ],
 )

-
-
 # For development, one often wants to make changes to the OpenXLA repository as well
 # as the PyTorch/XLA repository. You can override the pinned repository above with a
 # local checkout by either:
@@ -89,14 +89,14 @@ python_init_rules()
 load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories")

 python_init_repositories(
+    default_python_version = "system",
+    local_wheel_workspaces = ["@torch//:WORKSPACE"],
     requirements = {
         "3.8": "//:requirements_lock_3_8.txt",
         "3.9": "//:requirements_lock_3_9.txt",
         "3.10": "//:requirements_lock_3_10.txt",
         "3.11": "//:requirements_lock_3_11.txt",
     },
-    local_wheel_workspaces = ["@torch//:WORKSPACE"],
-    default_python_version = "system",
 )

 load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains")
@@ -111,8 +109,6 @@ load("@pypi//:requirements.bzl", "install_deps")

 install_deps()

-
-
 # Initialize OpenXLA's external dependencies.
 load("@xla//:workspace4.bzl", "xla_workspace4")

@@ -134,7 +130,6 @@ load("@xla//:workspace0.bzl", "xla_workspace0")

 xla_workspace0()

-
 load(
     "@xla//third_party/gpus:cuda_configure.bzl",
     "cuda_configure",

58 changes: 0 additions & 58 deletions bazel/torch.BUILD

This file was deleted.

100 changes: 100 additions & 0 deletions bazel/torch_repo.bzl
@@ -0,0 +1,100 @@
"""Repository rule to setup a torch repo."""

_BUILD_TEMPLATE = """
# Generated by //bazel:torch_repo.bzl

load("@//bazel:torch_targets.bzl", "define_torch_targets")

package(
default_visibility = [
"//visibility:public",
],
)

define_torch_targets()
""".lstrip()

def _get_url_basename(url):
basename = url.rpartition("/")[2]

# Starlark doesn't have any URL decode functions, so just approximate
# one with the cases we see.
return basename.replace("%2B", "+")

def _torch_repo_impl(rctx):
rctx.file("BUILD.bazel", _BUILD_TEMPLATE)

env_torch_whl = rctx.os.environ.get("TORCH_WHL", "")

urls = None
local_path = None
if env_torch_whl:
if env_torch_whl.startswith("http"):
urls = [env_torch_whl]
else:
local_path = rctx.path(env_torch_whl)
else:
dist_dir = rctx.workspace_root.get_child(rctx.attr.dist_dir)

if dist_dir.exists:
for child in dist_dir.readdir():
# For lack of a better option, take the first match
if child.basename.endswith(".whl"):
local_path = child
break

if not local_path and not urls:
fail((
"No torch wheel source configured:\n" +
"* Set TORCH_WHL environment variable to a local path or URL.\n" +
"* Or ensure the {dist_dir} directory is present with a torch wheel." +
"\n"
).format(
dist_dir = dist_dir,
))

if local_path:
whl_path = local_path
if not whl_path.exists:
fail("File not found: {}".format(whl_path))

# The dist/ directory is necessary for XLA's python_init_repositories
# to discover the wheel and add it to requirements.txt
rctx.symlink(whl_path, "dist/{}".format(whl_path.basename))
elif urls:
whl_basename = _get_url_basename(urls[0])

# The dist/ directory is necessary for XLA's python_init_repositories
# to discover the wheel and add it to requirements.txt
whl_path = rctx.path("dist/{}".format(whl_basename))
result = rctx.download(
url = urls,
output = whl_path,
)
if not result.success:
fail("Failed to download: {}", urls)

# Extract into the repo root. Also use .zip as the extension so that extract
# recognizes the file type.
# Use the whl basename so progress messages are more informative.
whl_zip = whl_path.basename.replace(".whl", ".zip")
rctx.symlink(whl_path, whl_zip)
rctx.extract(whl_zip)
rctx.delete(whl_zip)

torch_repo = repository_rule(
implementation = _torch_repo_impl,
doc = """
Creates a repository with torch headers, shared libraries, and wheel
for integration with Bazel.
""",
attrs = {
"dist_dir": attr.string(
doc = """
Directory with a prebuilt torch wheel. Typically points to a source checkout
that built a torch wheel. Relative paths are relative to the workspace root.
""",
),
},
environ = ["TORCH_WHL"],
)
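
Note: the %2B handling above exists because a wheel's local version tag uses "+" (for example torch-2.8.0+cpu), which appears percent-encoded in download URLs. A minimal Python illustration of the same transformation (the URL is invented for the example):

url = "https://example.com/whl/torch-2.8.0%2Bcpu-cp310-cp310-linux_x86_64.whl"
basename = url.rpartition("/")[2].replace("%2B", "+")
assert basename == "torch-2.8.0+cpu-cp310-cp310-linux_x86_64.whl"
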
64 changes: 64 additions & 0 deletions bazel/torch_repo_targets.bzl
@@ -0,0 +1,64 @@
"""Handles the loading phase to define targets for torch_repo."""

cc_library = native.cc_library

def define_torch_targets():
cc_library(
name = "headers",
hdrs = native.glob(
["torch/include/**/*.h"],
["torch/include/google/protobuf/**/*.h"],
),
strip_include_prefix = "torch/include",
)

# Runtime headers, for importing <torch/torch.h>.
cc_library(
name = "runtime_headers",
hdrs = native.glob(["torch/include/torch/csrc/api/include/**/*.h"]),
strip_include_prefix = "torch/include/torch/csrc/api/include",
)

native.filegroup(
name = "torchgen_deps",
srcs = [
# torchgen/packaged/ instead of aten/src
"torchgen/packaged/ATen/native/native_functions.yaml",
"torchgen/packaged/ATen/native/tags.yaml",
##"torchgen/packaged/ATen/native/ts_native_functions.yaml",
"torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.cpp",
"torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.h",
"torchgen/packaged/ATen/templates/LazyIr.h",
"torchgen/packaged/ATen/templates/LazyNonNativeIr.h",
"torchgen/packaged/ATen/templates/RegisterDispatchDefinitions.ini",
"torchgen/packaged/ATen/templates/RegisterDispatchKey.cpp",
# Add torch/include prefix
"torch/include/torch/csrc/lazy/core/shape_inference.h",
##"torch/csrc/lazy/ts_backend/ts_native_functions.cpp",
],
)

# Changed to cc_library from cc_import

cc_library(
name = "libtorch",
srcs = ["torch/lib/libtorch.so"],
)

cc_library(
name = "libtorch_cpu",
srcs = ["torch/lib/libtorch_cpu.so"],
)

cc_library(
name = "libtorch_python",
srcs = [
"torch/lib/libshm.so", # libtorch_python.so depends on this
"torch/lib/libtorch_python.so",
],
)

cc_library(
name = "libc10",
srcs = ["torch/lib/libc10.so"],
)
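
With these targets defined, consumers can link against the wheel's headers and shared libraries directly. A hypothetical consumer target (the name, source file, and dependency selection are invented for illustration):

cc_library(
    name = "uses_torch",
    srcs = ["uses_torch.cpp"],
    deps = [
        "@torch//:headers",
        "@torch//:libc10",
        "@torch//:libtorch",
        "@torch//:libtorch_cpu",
    ],
)
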
6 changes: 3 additions & 3 deletions codegen/lazy_tensor_generator.py
@@ -21,9 +21,9 @@
     kernel_signature,
 )

-aten_path = os.path.join(torch_root, "aten", "src", "ATen")
-shape_inference_hdr = os.path.join(torch_root, "torch", "csrc", "lazy", "core",
-                                   "shape_inference.h")
+aten_path = os.path.join(torch_root, "torchgen", "packaged", "ATen")
+shape_inference_hdr = os.path.join(torch_root, "torch", "include",
+                                   "torch", "csrc", "lazy", "core", "shape_inference.h")
 impl_path = os.path.join(xla_root, "__main__",
                          "torch_xla/csrc/aten_xla_type.cpp")
 source_yaml = sys.argv[2]