Fix the GPU multi host segmentation fault on XPK #129
We observe a segmentation fault when running multi-host GPU training with the xpk script and PR #121. The following code is the culprit:
IMO, we don't really need the jit function here. Removing it and calling init_train_state() directly solved the issue. Here are the detailed error messages:
ERROR 2024-10-24T06:38:42.674289966Z [resource.labels.containerName: gpu-image] Thread 0x00007d1b39ade740 (most recent call first):
ERROR 2024-10-24T06:38:42.674291558Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py", line 1287 in __call__
ERROR 2024-10-24T06:38:42.674293089Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py", line 333 in wrapper
ERROR 2024-10-24T06:38:42.674308967Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1668 in _pjit_call_impl_python
ERROR 2024-10-24T06:38:42.674313541Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1714 in call_impl_cache_miss
ERROR 2024-10-24T06:38:42.674318197Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1738 in _pjit_call_impl
ERROR 2024-10-24T06:38:42.674322419Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 955 in process_primitive
ERROR 2024-10-24T06:38:42.674327226Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 442 in bind_with_trace
ERROR 2024-10-24T06:38:42.674335880Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 2803 in bind
ERROR 2024-10-24T06:38:42.674349131Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 188 in _python_pjit_helper
ERROR 2024-10-24T06:38:42.674371601Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 338 in cache_miss
ERROR 2024-10-24T06:38:42.674385416Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/jax/_src/traceback_util.py", line 180 in reraise_with_filtered_traceback
ERROR 2024-10-24T06:38:42.674403826Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/maxdiffusion/max_utils.py", line 415 in setup_initial_state
ERROR 2024-10-24T06:38:42.674407243Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py", line 80 in create_unet_state
ERROR 2024-10-24T06:38:42.674435450Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/maxdiffusion/trainers/base_stable_diffusion_trainer.py", line 86 in start_training
ERROR 2024-10-24T06:38:42.674437794Z [resource.labels.containerName: gpu-image] File "/deps/src/maxdiffusion/train_sdxl.py", line 36 in train
ERROR 2024-10-24T06:38:42.674441074Z [resource.labels.containerName: gpu-image] File "/deps/src/maxdiffusion/train_sdxl.py", line 45 in main
ERROR 2024-10-24T06:38:42.674461026Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254 in _run_main
ERROR 2024-10-24T06:38:42.674465963Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308 in run
ERROR 2024-10-24T06:38:42.674473260Z [resource.labels.containerName: gpu-image] File "/deps/src/maxdiffusion/train_sdxl.py", line 49 in <module>
ERROR 2024-10-24T06:38:42.674516103Z [resource.labels.containerName: gpu-image] File "/usr/lib/python3.10/runpy.py", line 86 in _run_code
ERROR 2024-10-24T06:38:42.674555529Z [resource.labels.containerName: gpu-image] File "/usr/lib/python3.10/runpy.py", line 196 in _run_module_as_main
ERROR 2024-10-24T06:38:42.674662863Z [resource.labels.containerName: gpu-image] {}
ERROR 2024-10-24T06:38:42.675985140Z [resource.labels.containerName: gpu-image] Extension modules: charset_normalizer.md, requests.packages.charset_normalizer.md, requests.packages.chardet.md, yaml._yaml, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, PIL._imaging, jaxlib.cpu_feature_guard, msgpack._cmsgpack, google.protobuf.pyext._message, grpc._cython.cygrpc, torch._C, torch._C._dynamo.autograd_compiler, torch._C._dynamo.eval_frame, torch._C._dynamo.guards, torch._C._dynamo.utils, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, h5py._errors, h5py.defs, h5py._objects, h5py.h5, h5py.utils, h5py.h5t, h5py.h5s, h5py.h5ac, h5py.h5p, h5py.h5r, h5py._proxy, h5py._conv, h5py.h5z, h5py.h5a, h5py.h5d, h5py.h5ds, h5py.h5g, h5py.h5i, h5py.h5o, h5py.h5f, h5py.h5fd, h5py.h5pl, h5py.h5l, h5py._selector, scipy._lib._ccallback_c, scipy.sparse._sparsetools, _csparsetools, scipy.sparse._csparsetools, scipy.linalg._fblas, scipy.linalg._flapack, scipy.linalg.cython_lapack, scipy.linalg._cythonized_array_utils, scipy.linalg._solve_toeplitz, scipy.linalg._decomp_lu_cython, scipy.linalg._matfuncs_sqrtm_triu, scipy.linalg.cython_blas, scipy.linalg._matfuncs_expm, scipy.linalg._decomp_update, scipy.sparse.linalg._dsolve._superlu, scipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.linalg._propack._spropack, scipy.sparse.linalg._propack._dpropack, scipy.sparse.linalg._propack._cpropack, scipy.sparse.linalg._propack._zpropack, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, 
scipy.sparse.csgraph._reordering, psutil._psutil_linux, psutil._psutil_posix, pyarrow.lib, pandas._libs.tslibs.ccalendar, pandas._libs.tslibs.np_datetime, pandas._libs.tslibs.dtypes, pandas._libs.tslibs.base, pandas._libs.tslibs.nattype, pandas._libs.tslibs.timezones, pandas._libs.tslibs.fields, pandas._libs.tslibs.timedeltas, pandas._libs.tslibs.tzconversion, pandas._libs.tslibs.timestamps, pandas._libs.properties, pandas._libs.tslibs.offsets, pandas._libs.tslibs.strptime, pandas._libs.tslibs.parsing, pandas._libs.tslibs.conversion, pandas._libs.tslibs.period, pandas._libs.tslibs.vectorized, pandas._li…
ERROR 2024-10-24T06:39:50.837437957Z [resource.labels.containerName: gpu-image] gpu_multi_process_run.sh: line 152: 43 Segmentation fault (core dumped) python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml hardware=gpu run_name=sdxl output_dir=gs://lancewang-dev-supercomputer-testing/maxdiffusion_gpu pretrained_model_name_or_path=gs://lancewang-dev-supercomputer-testing/maxdiffusion_gpu/checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0/
INFO 2024-10-24T06:39:55.258113336Z [resource.labels.containerName: gpu-image] PID 24 failed with exit code 1
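For readers unfamiliar with the change, here is a minimal sketch of the two call patterns discussed in the description. The traceback shows the crash inside a pjit call made from setup_initial_state, and the fix is to call init_train_state() directly. The function body, its arguments, and the state layout below are hypothetical stand-ins, not the actual maxdiffusion code; the real signature of init_train_state() is not shown in this PR.

```python
import jax
import jax.numpy as jnp


def init_train_state(rng, features=4):
    """Hypothetical stand-in for maxdiffusion's init_train_state().

    Builds a tiny parameter tree plus a step counter. The real function
    and its arguments are assumptions made for illustration only.
    """
    w_key, _ = jax.random.split(rng)
    return {
        "params": {
            "w": jax.random.normal(w_key, (features, features)),
            "b": jnp.zeros((features,)),
        },
        "step": jnp.zeros((), dtype=jnp.int32),
    }


rng = jax.random.PRNGKey(0)

# Pattern before the fix: initialization wrapped in jax.jit.
jitted_state = jax.jit(init_train_state)(rng)

# Pattern after the fix: call the initializer directly, without jit.
direct_state = init_train_state(rng)

# Both paths produce a state tree with the same structure and shapes.
shapes_jit = jax.tree_util.tree_map(lambda x: x.shape, jitted_state)
shapes_direct = jax.tree_util.tree_map(lambda x: x.shape, direct_state)
print(shapes_jit == shapes_direct)
```

One trade-off to keep in mind, stated as a general property of the pattern rather than of this codebase: a direct call does not apply any shardings that a jit wrapper might have specified, so the state is created with default device placement.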