diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index 1807aaa106..af11ada34c 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -25,11 +25,8 @@ get_frameworks, cuda_path, get_max_jobs_for_parallel_build, - install_and_import, ) -install_and_import("pybind11[global]") - class CMakeExtension(setuptools.Extension): """CMake extension module""" diff --git a/build_tools/utils.py b/build_tools/utils.py index 4bb18a7002..948a3ec9f3 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -255,9 +255,7 @@ def get_frameworks() -> List[str]: _frameworks = [framework.lower() for framework in _frameworks] for framework in _frameworks: if framework not in supported_frameworks: - raise ValueError( - f"Transformer Engine does not support framework={framework}" - ) + raise ValueError(f"Transformer Engine does not support framework={framework}") return _frameworks @@ -297,17 +295,10 @@ def copy_common_headers( shutil.copy(path, new_path) -def pip_or_uv() -> List[str]: - if find_spec("pip") is not None: - return [sys.executable, "-m", "pip"] - else: - return ["/usr/bin/env", "uv", "pip"] - - def install_and_import(package): """Install a package via pip (if not already installed) and import into globals.""" main_package = package.split("[")[0] - subprocess.check_call([*pip_or_uv(), "install", package]) + subprocess.check_call([sys.executable, "-m", "pip", "install", package]) globals()[main_package] = importlib.import_module(main_package) diff --git a/setup.py b/setup.py index b020fa3c4f..3bb2fe6b95 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,6 @@ elif "paddle" in frameworks: from paddle.utils.cpp_extension import BuildExtension elif "jax" in frameworks: - install_and_import("jax[cuda12_local]") install_and_import("pybind11[global]") from pybind11.setup_helpers import build_ext as BuildExtension