This repository has been archived by the owner on Jun 21, 2024. It is now read-only.

Training with Hidet compiler #7

Open
RameshArvind opened this issue May 28, 2023 · 1 comment

Comments

@RameshArvind

RameshArvind commented May 28, 2023

Hello!
I was wondering if anything extra needs to be done to get training working with the Hidet compiler.

Out of the box, I seem to be running into errors:

import torch
from palm_rlhf_pytorch import PaLM

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    flash_attn = True, # https://arxiv.org/abs/2205.14135
    cross_entropy_ignore_index = 0
).to(torch.bfloat16).cuda()

palm_opt = torch.compile(palm, backend='hidet')

seq = torch.randint(0, 20000, (1, 1024)).cuda()

loss = palm_opt(seq, return_loss = True)
loss.backward()

Some of the errors I faced were around the usage of rearrange here and here.

It also seems like einsum isn't supported. Even after replacing those ops with equivalent alternatives, I'm still running into reshape errors from Hidet:

AssertionError: , occurred when interpreting reshape with
  tensor_reshape(tensor(...), [1023])
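
For reference, this is the kind of replacement I mean by "equivalent alternatives" (a hypothetical sketch; q and k are placeholder tensors, not names taken from the PaLM source):

import torch

# hypothetical attention-style shapes: (batch, heads, seq, dim_head)
q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 8, 1024, 64)

# einsum form, which the Hidet backend doesn't seem to support
sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)

# equivalent form without einsum
sim_alt = q @ k.transpose(-1, -2)

assert torch.allclose(sim, sim_alt, atol=1e-6)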

I can post additional info as needed, but I'm wondering if you ran into the same errors or if I'm doing something incorrectly.

Thanks!

@conceptofmind
Owner

Hi @RameshArvind ,

To use torch.compile with einops for training, you need to set:

from einops._torch_specific import allow_ops_in_compiled_graph  # requires einops>=0.6.1
allow_ops_in_compiled_graph()
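
Applied to your reproduction above, that would look something like this (a minimal sketch assuming the same PaLM setup):

from einops._torch_specific import allow_ops_in_compiled_graph  # requires einops>=0.6.1

allow_ops_in_compiled_graph()  # must be called before torch.compile traces the model

palm_opt = torch.compile(palm, backend='hidet')
loss = palm_opt(seq, return_loss=True)
loss.backward()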

I will have to do further investigation into the Hidet backend for training as I have only tested it out for inference.
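
For reference, the inference-only path I tested looks roughly like this (a minimal sketch reusing the palm and seq objects from your reproduction; my exact setup may have differed):

palm_opt = torch.compile(palm, backend='hidet')

with torch.no_grad():
    logits = palm_opt(seq)  # forward pass only; no return_loss, no backward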

I am also refactoring the code to provide a second training script that removes the Hugging Face dependency and works directly with PyTorch FSDP/DDP.
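
Roughly the direction that would take (a minimal sketch, not the actual refactored script; it assumes a standard torchrun launch, which sets the distributed environment variables):

import torch
from palm_rlhf_pytorch import PaLM
from torch.nn.parallel import DistributedDataParallel as DDP

torch.distributed.init_process_group(backend='nccl')
local_rank = torch.distributed.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

palm = PaLM(num_tokens=20000, dim=512, depth=12).cuda()
palm = DDP(palm, device_ids=[local_rank])
palm_opt = torch.compile(palm, backend='hidet')  # compile the wrapped module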

I will work on integrating torch.compile there.

Best,

Enrico
