
Fix the GPU multi host segmentation fault on XPK #129

Open
wants to merge 2 commits into main
Conversation

wang2yn84
Collaborator

We observe the following segmentation fault when running multi-host GPU training with the xpk script and PR#121. The following code is the culprit:

  state = jax.jit(
      init_train_state_partial,
      in_shardings=None,
      out_shardings=state_mesh_shardings,
  )()

IMO, we don't really need the jit wrapper here. Removing it and calling init_train_state() directly resolved the issue. Here are the detailed error messages:

ERROR 2024-10-24T06:38:42.674289966Z [resource.labels.containerName: gpu-image] Thread 0x00007d1b39ade740 (most recent call first):
ERROR 2024-10-24T06:38:42.674291558Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py", line 1287 in call
ERROR 2024-10-24T06:38:42.674293089Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py", line 333 in wrapper
ERROR 2024-10-24T06:38:42.674308967Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1668 in _pjit_call_impl_python
ERROR 2024-10-24T06:38:42.674313541Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1714 in call_impl_cache_miss
ERROR 2024-10-24T06:38:42.674318197Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1738 in _pjit_call_impl
ERROR 2024-10-24T06:38:42.674322419Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 955 in process_primitive
ERROR 2024-10-24T06:38:42.674327226Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 442 in bind_with_trace
ERROR 2024-10-24T06:38:42.674335880Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 2803 in bind
ERROR 2024-10-24T06:38:42.674349131Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 188 in _python_pjit_helper
ERROR 2024-10-24T06:38:42.674371601Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 338 in cache_miss
ERROR 2024-10-24T06:38:42.674385416Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/jax/_src/traceback_util.py", line 180 in reraise_with_filtered_traceback
ERROR 2024-10-24T06:38:42.674403826Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/maxdiffusion/max_utils.py", line 415 in setup_initial_state
ERROR 2024-10-24T06:38:42.674407243Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py", line 80 in create_unet_state
ERROR 2024-10-24T06:38:42.674435450Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/maxdiffusion/trainers/base_stable_diffusion_trainer.py", line 86 in start_training
ERROR 2024-10-24T06:38:42.674437794Z [resource.labels.containerName: gpu-image] File "/deps/src/maxdiffusion/train_sdxl.py", line 36 in train
ERROR 2024-10-24T06:38:42.674441074Z [resource.labels.containerName: gpu-image] File "/deps/src/maxdiffusion/train_sdxl.py", line 45 in main
ERROR 2024-10-24T06:38:42.674461026Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254 in _run_main
ERROR 2024-10-24T06:38:42.674465963Z [resource.labels.containerName: gpu-image] File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308 in run
ERROR 2024-10-24T06:38:42.674473260Z [resource.labels.containerName: gpu-image] File "/deps/src/maxdiffusion/train_sdxl.py", line 49 in
ERROR 2024-10-24T06:38:42.674516103Z [resource.labels.containerName: gpu-image] File "/usr/lib/python3.10/runpy.py", line 86 in _run_code
ERROR 2024-10-24T06:38:42.674555529Z [resource.labels.containerName: gpu-image] File "/usr/lib/python3.10/runpy.py", line 196 in _run_module_as_main
ERROR 2024-10-24T06:38:42.674662863Z [resource.labels.containerName: gpu-image] {}
ERROR 2024-10-24T06:39:50.837437957Z [resource.labels.containerName: gpu-image] gpu_multi_process_run.sh: line 152: 43 Segmentation fault (core dumped) python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml hardware=gpu run_name=sdxl output_dir=gs://lancewang-dev-supercomputer-testing/maxdiffusion_gpu pretrained_model_name_or_path=gs://lancewang-dev-supercomputer-testing/maxdiffusion_gpu/checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0/
INFO 2024-10-24T06:39:55.258113336Z [resource.labels.containerName: gpu-image] PID 24 failed with exit code 1
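For context, the proposed change can be sketched in isolation. The initializer below is a hypothetical stand-in for init_train_state_partial (the real function builds the full train state from the model and config); the sketch only shows that dropping the jit wrapper yields the same values:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for init_train_state_partial; the real function
# constructs the full train state from the model, optimizer, and config.
def init_train_state_partial():
    return {"params": jnp.ones((4, 4))}

# Original path: the initializer is wrapped in jax.jit.
jitted_state = jax.jit(init_train_state_partial)()

# Proposed workaround: call the initializer directly.
state = init_train_state_partial()

# The resulting values are identical; what differs is device placement
# once out_shardings enters the picture (see the review discussion).
assert jnp.allclose(jitted_state["params"], state["params"])
```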

@wang2yn84 wang2yn84 requested a review from entrpn October 25, 2024 18:18
@@ -407,8 +406,6 @@ def setup_initial_state(
eval_only=False,
)

state = jax.jit(init_train_state_partial, in_shardings=None, out_shardings=state_mesh_shardings)()
Collaborator
@wang2yn84 the jit function here shards the models according to the sharding specification in the config yml. In DDP, this won't make a difference since the state is replicated, but removing this function would prevent running FSDP or other sharding configurations.
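The reviewer's point can be illustrated with a small sketch (the toy initializer and the dict-valued sharding spec here are assumptions for illustration, not maxdiffusion code): passing out_shardings to jax.jit is what lays the returned state out across devices, so removing the jit call also removes the mechanism that makes FSDP-style sharding possible.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical toy initializer standing in for init_train_state_partial.
def init_state():
    return {"w": jnp.zeros((8, 4))}

# Build a 1-D device mesh over whatever devices are available.
mesh = Mesh(jax.devices(), ("data",))
spec = NamedSharding(mesh, PartitionSpec("data"))

# out_shardings tells jit how to lay out the result: the returned state
# comes back sharded along the "data" axis instead of replicated.
state = jax.jit(init_state, out_shardings={"w": spec})()
print(state["w"].sharding)
```

Calling init_state() directly would skip this placement step, which is harmless for DDP (replicated state) but breaks sharded configurations, matching the reviewer's concern.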
