PyTorch implementation of the paper "Lookahead Optimizer: k steps forward, 1 step back".
Usage:
base_opt = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999)) # any torch.optim optimizer can serve as the base
lookahead = Lookahead(base_opt, k=5, alpha=0.5) # wrap the base optimizer with Lookahead
lookahead.zero_grad()
loss_function(model(input), target).backward() # user-defined loss function
lookahead.step()
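As a rough illustration of what `k` and `alpha` control, here is a framework-free sketch of the update rule described in the paper: the base optimizer advances the "fast" weights for `k` steps, then the "slow" weights move a fraction `alpha` toward them and the fast weights are reset to the slow ones. The toy `SGD` class, the list-of-floats parameters, and passing gradients directly to `step()` are hypothetical simplifications for illustration; they are not this module's API or PyTorch's.

```python
class SGD:
    """Hypothetical toy base optimizer standing in for e.g. torch.optim.Adam."""

    def __init__(self, params, lr):
        self.params = params  # list of floats, updated in place
        self.lr = lr

    def step(self, grads):
        # Plain gradient descent on the fast weights.
        for i, g in enumerate(grads):
            self.params[i] -= self.lr * g


class Lookahead:
    """Sketch of the Lookahead rule: k fast steps, then one slow-weight update."""

    def __init__(self, base_opt, k=5, alpha=0.5):
        self.base_opt = base_opt
        self.k = k
        self.alpha = alpha
        self.slow = list(base_opt.params)  # slow weights start as a copy of the fast weights
        self.counter = 0

    def step(self, grads):
        self.base_opt.step(grads)  # one "forward" step of the fast weights
        self.counter += 1
        if self.counter % self.k == 0:
            for i, fast in enumerate(self.base_opt.params):
                # slow <- slow + alpha * (fast - slow)
                self.slow[i] += self.alpha * (fast - self.slow[i])
                # "1 step back": reset the fast weights to the slow weights
                self.base_opt.params[i] = self.slow[i]


# Minimize f(x, y) = (x - 3)^2 + (y + 1)^2, whose minimum is at (3, -1).
params = [0.0, 0.0]
la = Lookahead(SGD(params, lr=0.1), k=5, alpha=0.5)
for _ in range(100):
    x, y = params
    la.step([2 * (x - 3), 2 * (y + 1)])  # analytic gradient of f
print(params)
```

With `k=5` and `alpha=0.5`, the 100 calls to `step()` correspond to 20 slow-weight updates, after which the parameters sit close to the minimum at (3, -1).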