
PyTorch Implementation of the paper titled "Learning with Retrospection"


The-Learning-Machines/LearningWithRetrospection


☢️ Learning with Retrospection ☢️

About

  • LearningWithRetrospection.py contains the class LWR, which implements the algorithm from the paper.

Algorithm
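The core idea of learning with retrospection is to reuse the model's own temperature-softened predictions from earlier epochs as soft labels, alongside the true labels. The sketch below is a minimal, hypothetical PyTorch rendering of that idea, not the code in LearningWithRetrospection.py. Everything in it is an assumption: the class name RetrospectionLoss and its end_epoch method are made up for illustration, it takes explicit per-sample indices (the repo's LWR takes batch_idx instead), the true-label weight alpha decays linearly from 1 to 1 - update_rate over max_epochs, and the tau ** 2 scaling of the KL term is borrowed from standard knowledge distillation.

```python
import torch
import torch.nn.functional as F


class RetrospectionLoss:
    """Hypothetical sketch of learning with retrospection (not the repo's LWR)."""

    def __init__(self, dataset_length, num_classes, k=1, tau=5.0,
                 update_rate=0.9, max_epochs=20):
        self.k, self.tau = k, tau
        self.update_rate, self.max_epochs = update_rate, max_epochs
        # Soft labels used as targets, plus a buffer that collects this epoch's predictions.
        self.soft_labels = torch.zeros(dataset_length, num_classes)
        self.next_labels = torch.zeros(dataset_length, num_classes)
        self.epoch = 0

    def alpha(self):
        # Weight on the true-label loss: decays linearly from 1 to 1 - update_rate (assumed schedule).
        return 1.0 - self.update_rate * min(self.epoch / self.max_epochs, 1.0)

    def __call__(self, indices, logits, targets):
        ce = F.cross_entropy(logits, targets)
        # Record temperature-softened predictions for the next soft-label update.
        self.next_labels[indices] = F.softmax(logits.detach() / self.tau, dim=1).cpu()
        if self.epoch < self.k:
            return ce  # No retrospective soft labels exist yet.
        log_p = F.log_softmax(logits / self.tau, dim=1)
        targets_soft = self.soft_labels[indices].to(logits.device)
        kl = F.kl_div(log_p, targets_soft, reduction="batchmean")
        a = self.alpha()
        return a * ce + (1.0 - a) * self.tau ** 2 * kl

    def end_epoch(self):
        self.epoch += 1
        if self.epoch % self.k == 0:
            # Every k epochs, the recorded predictions become the new soft labels.
            self.soft_labels.copy_(self.next_labels)


# Toy run on random data: 64 samples, 8 features, 3 classes, batches of 16.
torch.manual_seed(0)
x = torch.randn(64, 8)
y = torch.randint(0, 3, (64,))
model = torch.nn.Linear(8, 3)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
crit = RetrospectionLoss(dataset_length=64, num_classes=3, k=1, max_epochs=5)

for epoch in range(5):
    for start in range(0, 64, 16):
        idx = torch.arange(start, start + 16)
        loss = crit(idx, model(x[idx]), y[idx])
        opt.zero_grad()
        loss.backward()
        opt.step()
    crit.end_epoch()
```

Keeping two buffers means the soft labels used during an epoch stay fixed while the current epoch's predictions are collected, and they only swap every k epochs.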

Usage

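    from torch.utils.data import DataLoader
    from LearningWithRetrospection import LWR

    # dataset, batch_size, model, optimizer and device are assumed to be defined already.
    max_epochs = 20
    # drop_last=True makes len(dataset) // batch_size equal the number of batches per epoch.
    # Assumption: LWR tracks samples through batch_idx, so batches keep a fixed order.
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    lwr = LWR(
        k=1,  # Interval, in epochs, between soft-label updates
        update_rate=0.9,  # Rate at which the weight on the true labels decays
        num_batches_per_epoch=len(dataset) // batch_size,
        dataset_length=len(dataset),
        output_shape=(10,),  # Shape of one sample's output, i.e. (num_classes,)
        tau=5,  # Softmax temperature; 5 is a reasonable default
        max_epochs=max_epochs,  # Total number of training epochs
        softmax_dim=1,  # Class dimension of the logits
    )

    for epoch in range(max_epochs):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = lwr(batch_idx, output, target, eval=False)  # LWR expects raw logits
            loss.backward()
            optimizer.step()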
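The tau argument is a softmax temperature. Assuming LWR softens predictions as softmax(logits / tau), the usual convention in knowledge distillation, this short sketch (with made-up logits) shows how a larger temperature flattens the distribution used as a soft label.

```python
import torch
import torch.nn.functional as F

# Hypothetical logits for a single sample over three classes.
logits = torch.tensor([[4.0, 1.0, 0.0]])

# Temperature-scaled softmax: dividing by a larger tau flattens the distribution.
soft = {tau: F.softmax(logits / tau, dim=1) for tau in (1.0, 5.0)}
for tau, probs in soft.items():
    print(f"tau={tau}: {probs.squeeze().tolist()}")
```

With tau=5 the non-target classes receive a much larger share of the probability mass than with tau=1, while the most likely class stays the same. That extra mass is what lets a soft label convey how similar the model considers the other classes.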
