
batch_idx rules for train_step #813

Open
linshokaku opened this issue Apr 25, 2024 · 0 comments
Labels
cat:docs Improvements or additions to documentation

Comments

@linshokaku (Member)

In the docs for `Logic.train_step`, `batch_idx` is described as the number of training steps already finished, but in reality it is the index of the iteration within the current epoch.

def train_step(
    self,
    models: Mapping[str, torch.nn.Module],
    optimizers: Mapping[str, torch.optim.Optimizer],
    batch_idx: int,
    batch: Any,
) -> Any:
    """A method invokes the model forward and backward passes.

    Optimizing is left to `train_step_optimizers` since maybe the user
    would like to aggregate the gradients of several iterations.

    Args:
        models (dict of torch.nn.Module):
            The models.
        optimizers (dict of torch.optim.Optimizer):
            The optimizers.
        batch_idx (int):
            Number of training steps already finished.
        batch (torch.Tensor, list of torch.Tensor, dict of torch.Tensor):
            Input tensors feeded to the model of the current step.
    """

Meanwhile, the trainer calls `train_step` from inside a per-epoch loop, passing the loop index:

for idx in range(train_len):
    ...
    self.handler.train_step(
        self,
        idx,
        x,
        complete_fn=self._complete_step,
    )
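To illustrate the discrepancy, here is a minimal sketch of an epoch-based training loop in the same shape as the excerpt above (the loop structure and names here are hypothetical stand-ins, not the actual pytorch-pfn-extras trainer): because the index comes from a `range(train_len)` that restarts every epoch, `batch_idx` repeats per epoch instead of counting all finished steps.

```python
def run_training(n_epochs, train_len, train_step):
    """Drive a per-epoch loop, passing the loop index as batch_idx."""
    for epoch in range(n_epochs):
        for idx in range(train_len):  # idx restarts at 0 every epoch
            train_step(batch_idx=idx, batch=None)

# Record every batch_idx the step function receives.
seen = []
run_training(
    n_epochs=2,
    train_len=3,
    train_step=lambda batch_idx, batch: seen.append(batch_idx),
)

# batch_idx repeats each epoch rather than growing monotonically:
# seen == [0, 1, 2, 0, 1, 2]
```

Under the documented reading ("number of training steps already finished"), the second epoch would instead see 3, 4, 5.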

@kmaehashi kmaehashi added the cat:docs Improvements or additions to documentation label Apr 25, 2024