
How to Adjust the Learning Rate in PyTorch

The importance of the learning rate goes without saying: it is a key training hyperparameter, and tuning it can feel rather like alchemy. Set it badly and, at best, you lose a few points of accuracy; at worst, training collapses. With the exact same network, an experienced practitioner will often train a better model than you simply by scheduling the learning rate well; how to pick the schedule itself is not covered further here. Commonly used scheduling strategies are linear | step | plateau | cosine. Below are the ways to adjust the learning rate in PyTorch.

import torch.optim as optim

def adjust_learning_rate(epoch, opt):
    """Step policy: sets the learning rate to the initial LR decayed by 10 every opt.step epochs
       (i.e. divide the learning rate by 10 every opt.step epochs)
    """
    lr = opt.lr * (0.1 ** (epoch // opt.step))
    return lr

# Create an optimizer (model and opt come from the surrounding training script)
optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.momentum,
                      weight_decay=opt.weight_decay)

# Before each parameter update, compute the new learning rate and write it
# into every parameter group of the optimizer
lr = adjust_learning_rate(epoch, opt)
for param_group in optimizer.param_groups:
    param_group["lr"] = lr

The official interface for this is lr_scheduler, imported as follows:
from torch.optim import lr_scheduler
(1) wrap the optimizer with one of the scheduler classes provided by lr_scheduler;
(2) call scheduler.step() to update the optimizer's learning rate.

These cover the usual policies: linear | step | plateau | cosine (a minimal example follows).
For details see: pytorch.org/docs/stable
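
A minimal sketch of those two steps, here using the built-in StepLR policy (num_epochs, the optimizer, and the per-epoch training code are placeholders):

# (1) wrap the existing optimizer with a scheduler class
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

for epoch in range(num_epochs):
    # ... train for one epoch; this is where optimizer.step() gets called ...
    # (2) then let the scheduler write the new learning rate into the optimizer
    scheduler.step()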


Below is a small, ready-made demo taken from CycleGAN:

def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler

    Parameters:
        optimizer      -- the optimizer of the network
        opt.lr_policy  -- the name of the learning rate policy: linear | step | plateau | cosine
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
    else:
        return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
    return scheduler
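
A rough sketch of driving the returned scheduler once per epoch (num_epochs and val_loss are placeholders); note that the ReduceLROnPlateau branch is stepped with a monitored metric rather than unconditionally:

scheduler = get_scheduler(optimizer, opt)

for epoch in range(num_epochs):
    # ... train for one epoch, then evaluate to obtain val_loss ...
    if opt.lr_policy == 'plateau':
        scheduler.step(val_loss)      # plateau scheduling reacts to a monitored metric
    else:
        scheduler.step()              # the other policies simply advance by one epoch
    print('learning rate = %.7f' % optimizer.param_groups[0]['lr'])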

Step (2) is then just scheduler.step(): that single call is all it takes to adjust the learning rate.
Note the order here: optimizer.step() should be called first and scheduler.step() after it; if you get the order backwards, PyTorch will print a warning.
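
A minimal sketch of that ordering inside a training loop (the data loader, loss, and num_epochs are placeholders):

for epoch in range(num_epochs):
    for images, targets in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(images), targets)
        loss.backward()
        optimizer.step()          # update the parameters first
    scheduler.step()              # then advance the schedule, once per epoch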

Looking at the scheduler source, the core is still a for loop that writes the new learning rate into each parameter group. Below is the step() method of the _LRScheduler base class (the full warning messages are abridged):

def step(self, epoch=None):
    # Raise a warning if old pattern is detected
    # https://github.com/pytorch/pytorch/issues/20124
    if self._step_count == 1:
        if not hasattr(self.optimizer.step, "_with_counter"):
            warnings.warn("Seems like `optimizer.step()` has been overridden "
                          "after learning rate scheduler initialization...", UserWarning)
        # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
        elif self.optimizer._step_count < 1:
            warnings.warn("Detected call of `lr_scheduler.step()` "
                          "before `optimizer.step()`...", UserWarning)
    self._step_count += 1

    if epoch is None:
        epoch = self.last_epoch + 1
    self.last_epoch = epoch
    for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
        param_group['lr'] = lr
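
To see the effect of that loop, you can print the learning rate held by the optimizer after each step(). A small self-contained check (toy model and values chosen purely for illustration):

import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

model = nn.Linear(2, 2)                                   # toy model, just to have parameters
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

for epoch in range(30):
    optimizer.step()                                      # would normally follow loss.backward()
    scheduler.step()                                      # writes the new lr into param_groups
    if (epoch + 1) % 10 == 0:
        print(epoch + 1, optimizer.param_groups[0]['lr'])
# the printed learning rate drops by 10x every 10 epochs: ~0.01, ~0.001, ~0.0001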
