Skip to content

Commit

Permalink
revert to just uninstall_te_wheel_packages change
Browse files Browse the repository at this point in the history
  • Loading branch information
jennifgcrl committed Nov 8, 2024
1 parent e821d5a commit 4170941
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 15 deletions.
3 changes: 0 additions & 3 deletions build_tools/build_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,8 @@
get_frameworks,
cuda_path,
get_max_jobs_for_parallel_build,
install_and_import,
)

install_and_import("pybind11[global]")


class CMakeExtension(setuptools.Extension):
"""CMake extension module"""
Expand Down
13 changes: 2 additions & 11 deletions build_tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,7 @@ def get_frameworks() -> List[str]:
_frameworks = [framework.lower() for framework in _frameworks]
for framework in _frameworks:
if framework not in supported_frameworks:
raise ValueError(
f"Transformer Engine does not support framework={framework}"
)
raise ValueError(f"Transformer Engine does not support framework={framework}")

return _frameworks

Expand Down Expand Up @@ -297,17 +295,10 @@ def copy_common_headers(
shutil.copy(path, new_path)


def pip_or_uv() -> List[str]:
if find_spec("pip") is not None:
return [sys.executable, "-m", "pip"]
else:
return ["/usr/bin/env", "uv", "pip"]


def install_and_import(package):
"""Install a package via pip (if not already installed) and import into globals."""
main_package = package.split("[")[0]
subprocess.check_call([*pip_or_uv(), "install", package])
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
globals()[main_package] = importlib.import_module(main_package)


Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
elif "paddle" in frameworks:
from paddle.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
install_and_import("jax[cuda12_local]")
install_and_import("pybind11[global]")
from pybind11.setup_helpers import build_ext as BuildExtension

Expand Down

0 comments on commit 4170941

Please sign in to comment.