Custom training loop for LightningModule #6456
-
Hello, I was wondering if it is possible to control the train-loop behavior of a module (beyond overriding `training_step`)? For example, let's say I have this routine:

```python
import torch.nn as nn
import torch.optim as optim

m_0 = MyModel()
loader_1 = getTrainLoader(1)
loader_2 = getTrainLoader(2)
loader_3 = getTrainLoader(3)

# train the first two models
m_1 = train_model_for_one_epoch(m_0, loader_1)
m_2 = train_model_for_one_epoch(m_1, loader_2)

# train the third model based on the previous models
m_3 = MyModel()
criterion = nn.CrossEntropyLoss()
lr = 0.01  # example value
optimizer = optim.SGD(m_3.parameters(), lr)

# main train loop
for data, target in loader_3:
    # gradient of the loss w.r.t. m_1's parameters
    m_1.zero_grad()
    loss_1 = criterion(m_1(data), target)
    loss_1.backward()
    grad_1 = get_gradient_vector(m_1)

    # gradient of the loss w.r.t. m_2's parameters
    m_2.zero_grad()
    loss_2 = criterion(m_2(data), target)
    loss_2.backward()
    grad_2 = get_gradient_vector(m_2)

    # manually calculate & set the gradient of m_3
    # (get_gradient_vector / set_model_gradient are user-defined helpers)
    grad_3 = (grad_1 + grad_2) / 2.0
    set_model_gradient(m_3, grad_3)
    optimizer.step()
```
How can I implement the final loop in the above code in PL? Thanks!
-
Currently there's no easy way for users to manage the dataloaders themselves, but you can perform the optimization (and manipulate the gradients) by setting `automatic_optimization=False`; see: https://pytorch-lightning.readthedocs.io/en/latest/common/optimizers.html#manual-optimization
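
As a rough illustration of how the final loop could be translated under manual optimization: the sketch below is an assumption-laden example, not an official recipe. The name `AveragedGradModule` is hypothetical, it assumes `m_1`, `m_2`, and `m_3` share the same architecture (so their parameter tensors line up), and it assumes a recent PL version where `automatic_optimization` can be set as an attribute in `__init__`. It moves the gradient averaging into `training_step`, using `self.manual_backward()` and `self.optimizers()` from the manual-optimization API.

```python
import torch
import torch.nn as nn
import pytorch_lightning as pl

class AveragedGradModule(pl.LightningModule):
    # hypothetical module name; m_1 and m_2 are the pre-trained models,
    # m_3 is the model actually being optimized
    def __init__(self, m_1, m_2, m_3, lr=0.01):
        super().__init__()
        self.automatic_optimization = False  # take over the optimization loop
        self.m_1, self.m_2, self.m_3 = m_1, m_2, m_3
        self.criterion = nn.CrossEntropyLoss()
        self.lr = lr

    def configure_optimizers(self):
        # only m_3's parameters are updated by the optimizer
        return torch.optim.SGD(self.m_3.parameters(), lr=self.lr)

    def training_step(self, batch, batch_idx):
        data, target = batch
        opt = self.optimizers()

        # gradients of the loss w.r.t. m_1's and m_2's parameters
        self.m_1.zero_grad()
        self.manual_backward(self.criterion(self.m_1(data), target))
        self.m_2.zero_grad()
        self.manual_backward(self.criterion(self.m_2(data), target))

        # average the two gradient sets and write them into m_3
        # (assumes identical parameter shapes across the three models)
        opt.zero_grad()
        for p1, p2, p3 in zip(self.m_1.parameters(),
                              self.m_2.parameters(),
                              self.m_3.parameters()):
            p3.grad = (p1.grad + p2.grad) / 2.0
        opt.step()
```

You would then train it on the third loader as usual, e.g. `trainer = pl.Trainer(max_epochs=1)` followed by `trainer.fit(AveragedGradModule(m_1, m_2, MyModel()), loader_3)`.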