Hi,
I'm trying to run the PyTorch training implementation on an Apple M2 chip with MPS. I can run StyleGAN-ADA image generation following these steps, but when I try to train with DiffAugment I get this error:
/AppleInternal/Library/BuildRoots/20d6c351-ee94-11ec-bcaf-7247572f23b4/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSNDArray/Kernels/MPSNDArrayConvolutionA14.mm:3237: failed assertion `destination datatype must be fp32'
My steps so far:
Clone the repo, `cd data-efficient-gans/DiffAugment-stylegan2-pytorch`, and set up the environment:
conda create -n DiffAug python=3.9
conda activate DiffAug
conda install pytorch torchvision torchaudio -c pytorch
pip install click requests tqdm pyspng ninja imageio-ffmpeg==0.4.3
pip install Pillow psutil scipy
I replace all instances of `torch.device('cuda')` with `torch.device('mps')`.
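Concretely, I pick the device with a fallback (a minimal sketch; `torch.backends.mps.is_available()` is the standard availability check):

```python
import torch

# Prefer the Metal backend on Apple Silicon, fall back to CPU otherwise.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
```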
I replace the random array generation with `random_array = np.random.RandomState(seed).randn(1, G.z_dim).astype(np.float32)` in generate.py, as described, and feed it to the generator as sketched below.
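Putting the two together, generation looks roughly like this (a sketch: `G` and `device` come from the surrounding generate.py code, the seed is illustrative, and the label handling follows the upstream unconditional case):

```python
import numpy as np
import torch

seed = 42  # illustrative seed
random_array = np.random.RandomState(seed).randn(1, G.z_dim).astype(np.float32)
z = torch.from_numpy(random_array).to(device)     # fp32 latent on MPS
label = torch.zeros([1, G.c_dim], device=device)  # empty label for unconditional models
img = G(z, label, truncation_psi=1.0, noise_mode='const')
```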
In training_loop.py I replace the `torch.cuda.Event(enable_timing=True)` timing with `time.perf_counter()`.
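The pattern I substitute is roughly the following (a sketch; `torch.mps.synchronize()` exists in recent PyTorch builds and makes the wall-clock numbers account for queued GPU work, loosely mirroring what the CUDA events measured):

```python
import time
import torch

# CUDA original: paired torch.cuda.Event objects plus start.elapsed_time(end).
# MPS replacement: wall-clock timing around the same region.
torch.mps.synchronize()                  # drain queued GPU work first
start_time = time.perf_counter()
# ... timed section of the training loop ...
torch.mps.synchronize()
elapsed_ms = (time.perf_counter() - start_time) * 1000.0
```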
I remove `torch.backends.cuda.matmul.allow_tf32 = allow_tf32`, `torch.backends.cudnn.allow_tf32 = allow_tf32`, `torch.cuda.reset_peak_memory_stats()`, and the pinned-memory staging in `all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device)`, as they have no MPS equivalent (see the sketch below).
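For the last of these, I believe only the `.pin_memory()` call has to go; pinned host memory is a CUDA-specific staging optimization, and the plain host-to-device copy still works (a sketch; `all_gen_c` is the list built earlier in training_loop.py):

```python
import numpy as np
import torch

# Before (CUDA): torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device)
all_gen_c = torch.from_numpy(np.stack(all_gen_c)).to(device)
```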
At this point I can generate images from the pretrained model, but training aborts with the message below:
python train.py --outdir=../training-runs --data=../datasets/100-shot-obama.zip --gpus=1 --kimg 1
# /AppleInternal/Library/BuildRoots/20d6c351-ee94-11ec-bcaf-7247572f23b4/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSNDArray/Kernels/MPSNDArrayConvolutionA14.mm:3237: failed assertion `destination datatype must be fp32'
Using pdb I can trace the error from
.../DiffAugment-stylegan2-pytorch/training/loss.py(80)accumulate_gradients() -> loss_Gmain.mean().mul(gain).backward()
to
torch/autograd/graph.py(769)_engine_run_backward() -> return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
but I'm not able to figure out what's going on from there. I have checked that all tensors in the training loop are float32.
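For reference, this is roughly how I audited dtypes beyond pdb: forward hooks that flag any non-fp32 activation (a sketch; `G` and `D` are the networks from training_loop.py). One thing I'm unsure about: if this fork keeps the upstream StyleGAN2-ADA `num_fp16_res` behavior, the highest-resolution conv layers are cast to fp16 internally, which a pdb inspection of the loop's own tensors would not catch.

```python
import torch

def report_non_fp32(module, inputs, output):
    # Print any forward output that is not float32.
    outs = output if isinstance(output, (tuple, list)) else (output,)
    for o in outs:
        if torch.is_tensor(o) and o.dtype != torch.float32:
            print(f'{type(module).__name__}: {o.dtype}')

# G and D are the generator/discriminator built in training_loop.py.
hooks = [m.register_forward_hook(report_non_fp32)
         for net in (G, D) for m in net.modules()]
# ... run one forward/backward step, then detach the hooks ...
for h in hooks:
    h.remove()
```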
Any suggestions would be appreciated! I don't have access to NVIDIA GPUs at the moment and the Colab also seems to be outdated.