code repetition in train methods #922

Open · janfb opened this issue Jan 24, 2024 · 1 comment
Labels: architecture (Internal changes without API consequences), enhancement (New feature or request), hackathon

janfb (Contributor) commented Jan 24, 2024

Description:

The current implementation of the SBI library contains significant code duplication within the train(...) methods of SNPE, SNRE, and SNLE. These methods share many common functionalities, including:

  • Building the neural network
  • Resuming training
  • Managing the training and validation loops

This redundancy increases the complexity of the codebase, making it harder to maintain and more prone to inconsistencies and bugs, particularly during updates or enhancements.

To address this, we propose refactoring these methods by introducing a unified train function in the base class. This common train function would handle the shared parts of the training process and accept method-specific losses and other relevant keyword arguments to account for the differences between SNPE, SNRE, and SNLE.
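
As a rough illustration, a unified loop in the base class could look like the sketch below. This is not sbi's actual API: the class name NeuralInferenceBase and the hook name _get_losses are hypothetical placeholders, and the loop body is condensed from the snippets listed under "Example redundancies" below.

    import torch
    from torch.nn.utils import clip_grad_norm_


    class NeuralInferenceBase:  # hypothetical name for the shared base class
        def _train_loop(
            self,
            train_loader,
            max_num_epochs,
            stop_after_epochs,
            clip_max_norm=None,
            **loss_kwargs,
        ):
            """Shared training loop; only the loss computation is method-specific."""
            while self.epoch <= max_num_epochs and not self._converged(
                self.epoch, stop_after_epochs
            ):
                # Train for a single epoch.
                self._neural_net.train()
                train_log_probs_sum = 0
                for batch in train_loader:
                    self.optimizer.zero_grad()
                    # Delegate batch unpacking and loss evaluation to the subclass.
                    train_losses = self._get_losses(batch, **loss_kwargs)
                    train_loss = torch.mean(train_losses)
                    train_log_probs_sum -= train_losses.sum().item()
                    train_loss.backward()
                    if clip_max_norm is not None:
                        clip_grad_norm_(
                            self._neural_net.parameters(), max_norm=clip_max_norm
                        )
                    self.optimizer.step()
                self.epoch += 1
                train_log_prob_average = train_log_probs_sum / (
                    len(train_loader) * train_loader.batch_size
                )
                self._summary["training_log_probs"].append(train_log_prob_average)

        def _get_losses(self, batch, **loss_kwargs):
            """Hypothetical hook: subclasses unpack the batch and call their loss."""
            raise NotImplementedError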

Example redundancies

  • SNPE:
    while self.epoch <= max_num_epochs and not self._converged(
        self.epoch, stop_after_epochs
    ):
        # Train for a single epoch.
        self._neural_net.train()
        train_log_probs_sum = 0
        epoch_start_time = time.time()
        for batch in train_loader:
            self.optimizer.zero_grad()
            # Get batches on current device.
            theta_batch, x_batch, masks_batch = (
                batch[0].to(self._device),
                batch[1].to(self._device),
                batch[2].to(self._device),
            )
            train_losses = self._loss(
                theta_batch,
                x_batch,
                masks_batch,
                proposal,
                calibration_kernel,
                force_first_round_loss=force_first_round_loss,
            )
            train_loss = torch.mean(train_losses)
            train_log_probs_sum -= train_losses.sum().item()
            train_loss.backward()
            if clip_max_norm is not None:
                clip_grad_norm_(
                    self._neural_net.parameters(), max_norm=clip_max_norm
                )
            self.optimizer.step()
        self.epoch += 1
        train_log_prob_average = train_log_probs_sum / (
            len(train_loader) * train_loader.batch_size  # type: ignore
        )
        self._summary["training_log_probs"].append(train_log_prob_average)
  • SNLE:
    while self.epoch <= max_num_epochs and not self._converged(
        self.epoch, stop_after_epochs
    ):
        # Train for a single epoch.
        self._neural_net.train()
        train_log_probs_sum = 0
        for batch in train_loader:
            self.optimizer.zero_grad()
            theta_batch, x_batch = (
                batch[0].to(self._device),
                batch[1].to(self._device),
            )
            # Evaluate on x with theta as context.
            train_losses = self._loss(theta=theta_batch, x=x_batch)
            train_loss = torch.mean(train_losses)
            train_log_probs_sum -= train_losses.sum().item()
            train_loss.backward()
            if clip_max_norm is not None:
                clip_grad_norm_(
                    self._neural_net.parameters(),
                    max_norm=clip_max_norm,
                )
            self.optimizer.step()
        self.epoch += 1
        train_log_prob_average = train_log_probs_sum / (
            len(train_loader) * train_loader.batch_size  # type: ignore
        )
        self._summary["training_log_probs"].append(train_log_prob_average)
  • SNRE:
    while self.epoch <= max_num_epochs and not self._converged(
        self.epoch, stop_after_epochs
    ):
        # Train for a single epoch.
        self._neural_net.train()
        train_log_probs_sum = 0
        for batch in train_loader:
            self.optimizer.zero_grad()
            theta_batch, x_batch = (
                batch[0].to(self._device),
                batch[1].to(self._device),
            )
            train_losses = self._loss(
                theta_batch, x_batch, num_atoms, **loss_kwargs
            )
            train_loss = torch.mean(train_losses)
            train_log_probs_sum -= train_losses.sum().item()
            train_loss.backward()
            if clip_max_norm is not None:
                clip_grad_norm_(
                    self._neural_net.parameters(),
                    max_norm=clip_max_norm,
                )
            self.optimizer.step()
        self.epoch += 1
        train_log_prob_average = train_log_probs_sum / (
            len(train_loader) * train_loader.batch_size  # type: ignore
        )
        self._summary["training_log_probs"].append(train_log_prob_average)

Proposed Steps

  • Identify and abstract the common code segments across the train methods of SNPE, SNRE, and SNLE.
  • Design a generic train function in the base class that accepts the method-specific losses and other arguments unique to each method. Parts shared by some, but not all, methods should be offloaded into separate class methods that child classes can override if required (see the sketch after this list).
  • Refactor the existing train methods to utilize the new generic function, passing their specific requirements as arguments.
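
In the snippets above, only the batch unpacking and the call to self._loss differ between the three methods; the surrounding loop structure is identical. Continuing the hypothetical sketch from above (the hook and class names are placeholders, not sbi's actual API), the subclasses would then only override the loss hook, for example:

    class SNLEStyle(NeuralInferenceBase):  # hypothetical subclass name
        def _get_losses(self, batch, **loss_kwargs):
            theta_batch, x_batch = (
                batch[0].to(self._device),
                batch[1].to(self._device),
            )
            # Evaluate on x with theta as context.
            return self._loss(theta=theta_batch, x=x_batch)


    class SNPEStyle(NeuralInferenceBase):  # hypothetical subclass name
        def _get_losses(
            self,
            batch,
            proposal=None,
            calibration_kernel=None,
            force_first_round_loss=False,
            **loss_kwargs,
        ):
            theta_batch, x_batch, masks_batch = (
                batch[0].to(self._device),
                batch[1].to(self._device),
                batch[2].to(self._device),
            )
            return self._loss(
                theta_batch,
                x_batch,
                masks_batch,
                proposal,
                calibration_kernel,
                force_first_round_loss=force_first_round_loss,
            )

SNRE would analogously pass num_atoms and its loss_kwargs through the same hook, so none of the subclasses would need to reimplement the epoch and batch loops.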

We encourage contributors to discuss strategies for this refactoring and help with the implementation. This effort will improve the library’s maintainability and ensure consistency across its components.

If you identify other areas where significant code duplication can be reduced, please create a new issue (e.g., #921).

janfb added the enhancement, architecture, and hackathon labels on Jan 24, 2024
janfb added this to the Pre Hackathon 2024 milestone on Feb 6, 2024
janfb self-assigned this on Feb 16, 2024
janfb removed the hackathon label on Jul 22, 2024
janfb (Contributor, Author) commented Jul 22, 2024

This will become even more relevant when we have a common dataloader interface and agnostic loss functions for all SBI methods. But I am removing the hackathon label for now as it will not be done before the release.

janfb removed this from the Hackathon 2024 milestone on Jul 22, 2024