调整学习率的重要性不言而喻,这是一个重要的炼丹参数,很魔幻很玄学,调不好轻则精度掉点,重则模型崩塌。可能同样的网络,有经验的老师傅训出来的就是比你的好,具体怎么调学习率这里不再过多赘述。常用的学习率调整策略有linear | step | plateau | cosine,下面谈一下Pytorch中调整学习率的方法
def adjust_learning_rate(epoch, opt):#定义一个学习率衰减策略函数,这是step
"""Sets the learning rate to the initial LR decayed by 10 every 10 epochs
每10个epoch学习率除以10
"""
lr = opt.lr * (0.1 ** (epoch // opt.step))
return lr
#定义一个优化器
optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.momentum,weight_decay=opt.weight_decay)
for param_group in optimizer.param_groups:#在每次更新参数前迭代更改学习率
param_group["lr"] = lr
官方提供的接口是lr_scheduler,导入方法如下 from torch.optim import lr_scheduler
(1)使用lr_scheduler内置的类包裹一下优化器
(2)scheduler.step()更新优化器的学习率
官方提供了一下学习率策略:linear | step | plateau | cosine
具体可见: https://pytorch.org/docs/stable/optim.html
下面是来自Cycle-GAN里面写好的小Demo
def get_scheduler(optimizer, opt):
"""Return a learning rate scheduler
Parameters:
optimizer -- 网络优化器
opt.lr_policy -- 学习率scheduler的名称: linear | step | plateau | cosine
"""
if opt.lr_policy == 'linear':
def lambda_rule(epoch):
lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
return lr_l
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
elif opt.lr_policy == 'step':
scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
elif opt.lr_policy == 'plateau':
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
elif opt.lr_policy == 'cosine':
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
else:
return NotImplementedError('learning rate policy[%s]is not implemented', opt.lr_policy)
return scheduler
2.然后 scheduler.step() #调整学习率只需要step()就可以了
此处要注意,应该先optimizer.step(),然后再scheduler.step(),如果次序反了,pytorch也会提示的
可以看到核心代码还是for循环修改参数的学习率
def step(self, epoch=None):
# Raise a warning if old pattern is detected
# https://github.com/pytorch/pytorch/issues/20124
if self._step_count==1:
if not hasattr(self.optimizer.step, "_with_counter"):
# Just check if there were two first lr_scheduler.step() calls before optimizer.step()
elif self.optimizer._step_count < 1:
self._step_count +=1
if epoch is None:
epoch=self.last_epoch + 1
self.last_epoch=epoch
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr']=lr
公司名称: 开丰娱乐-开丰五金配件机电公司
手 机: 13800000000
电 话: 400-123-4567
邮 箱: admin@youweb.com
地 址: 广东省广州市天河区88号