Skip to content

Commit

Permalink
update transformers test requirements
Browse files Browse the repository at this point in the history
  • Loading branch information
tianleiwu committed Nov 21, 2024
1 parent e430795 commit b7cc8e8
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
16 changes: 16 additions & 0 deletions tools/ci_build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -2074,6 +2074,21 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs):
if is_windows():
cwd = os.path.join(cwd, config)

# Install PyTorch which is required for transformers tests, and optional for some python tests.
if args.enable_transformers_tool_test and not args.disable_contrib_ops and not args.use_rocm:
index_url = "https://download.pytorch.org/whl/cpu"
if args.use_cuda and is_linux():
index_url = "https://download.pytorch.org/whl/cu124"
if args.cuda_version and version_to_tuple(args.cuda_version) < (12, 0):
index_url = "https://download.pytorch.org/whl/cu118"

run_subprocess(
[sys.executable, "-m", "pip", "install", f"torch", "--index-url", index_url],
cwd=cwd,
dll_path=dll_path,
python_path=python_path,
)

run_subprocess(
[sys.executable, "onnxruntime_test_python.py"], cwd=cwd, dll_path=dll_path, python_path=python_path
)
Expand Down Expand Up @@ -2128,6 +2143,7 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs):
dll_path=dll_path,
python_path=python_path,
)

if not args.disable_contrib_ops:
run_subprocess(
[sys.executable, "-m", "unittest", "discover", "-s", "quantization"], cwd=cwd, dll_path=dll_path
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# packages used by transformers python unittest (only enabled in Linux CPU CI Pipeline)
packaging
protobuf==3.20.2
numpy==1.24.0 ; python_version < '3.12'
numpy==1.26.0 ; python_version >= '3.12'
# protobuf and numpy is same as tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt
protobuf==4.21.12
numpy==1.21.6 ; python_version < '3.9'
numpy==2.0.0 ; python_version >= '3.9'
torch
coloredlogs==15.0
transformers==4.38.0
Expand Down

0 comments on commit b7cc8e8

Please sign in to comment.