前言
复习一下基础工程问题。
torch.optim.LRScheduler
LRScheduler是一个模板类。具体的schedluer类型需要继承它。
每种LRScheduler的具体实现,都需要提供 get_lr()
无参方法的implementation。
所有需要的参数通过self传递。
默认情况使用无参 step()
进行 lr 的更新。(带参版本deprecated)。
简化逻辑如下:
def step(self, epoch=None):
if epoch is None:
self.last_epoch += 1
values = self.get_lr()
for i, data in enumerate(zip(self.optimizer.param_groups, values)):
param_group, lr = data
param_group['lr'] = lr
self.print_lr(self.verbose, i, lr, epoch)
简单说就是通过get_lr()
拿到长度等于num_groups的list of values。
对group的管理
只使用optimizer的时候每个group只关心自己的group['lr']
。
在我们引入LRScheduler后,
会在init阶段,
默认情况下(即不给出last_epoch
时),
将group['lr']
移植给 group['initial_lr']
。
而后scheduler 在 group['initial_lr']
的基础上进行各种 lr
操作。
通常是linear变化。
last_epoch的心智模型
torch api存在一种隐形约定。
注意到step()
时 先有 self.last_epoch += 1
,再有 values = self.get_lr()
。
所以我们实现 get_lr()
时可以根据 self.last_epoch
去定位当前需要的epoch。
这样一来似乎有点奇怪,明明命名是last,使用时却变成了current。
所以需要对齐一下心智模型。
事实上,LRScheduler类在__init__()
时会call一次_initial_step()
,
这时候就会step()
一次。
即初始化阶段,就会让 self.last_epoch
从-1变成0。
然后根据self.last_epoch==0
计算一次get_lr
。
所以last_epoch == 0
的意思并非是我们已经完成了epoch0,即将开始epoch1。
而是指,已经完成的epoch的计数为0,
我们即将开始0-th epoch
。
即事实上地,这个参数应该作为current_epoch
来运用。
但torch不知道为什么选了last
这个名字。
作为例证,我们来看一下其他库的实现。
例如大名鼎鼎的 timm
。
# from timm.scheduler.cosine_lr
class CosineLRLambda:
def __init__(self, scheduler_params):
self.warmup_epochs = scheduler_params["warmup_epochs"]
self.lr_warmup_factor = scheduler_params["warmup_factor"]
self.max_epochs = scheduler_params["epochs"]
self.lr_min_factor = scheduler_params["lr_min_factor"]
def __call__(self, current_step):
# `warmup_epochs` is already multiplied with the num of iterations
if current_step < self.warmup_epochs:
alpha = current_step / float(self.warmup_epochs)
return self.lr_warmup_factor * (1.0 - alpha) + alpha
else:
if current_step >= self.max_epochs:
return self.lr_min_factor
lr_scale = self.lr_min_factor + 0.5 * (1 - self.lr_min_factor) * (
1 + math.cos(math.pi * (current_step / self.max_epochs))
)
return lr_scale
注意到timm
的上述实现中,唯一入参为 current_step
。
这样第一次调用时,父类torch.optim.LambdaLR
会在get_lr()
中调用lmbda(self.last_epoch)
。
这个self.last_epoch==0
传入到 CosineLRLambda.__call__()
中就变成了current_step == 0
。
证明 last_epoch 确实是被作为 current用的。
小结
在这种心智模型下,
torch的隐形假设是,
我们总是在一个epoch的结尾才调用scheduler.step()
。
因为第0epoch一开始的那个step() 放在 init() 里面完成了。
下一个step()应该在0epoch的结尾处产生。
于是我们开发者在get_lr()
的实现中,
应该把self.last_epoch
参数视作 current_epoch
的含义来使用。
或者更严谨地说,应该是next_epoch
来使用,因为总是先 self.last_epoch +=1
再 get_lr()
。
选用last应该是一个torch的命名失误。
保存state_dict的顺序问题
考虑到上述last_epoch的心智模型。
在每个epoch结束后,有2种做法。
-
应该先进行
scheduler.step()
,切换到next_epoch
andnext lr
,再保存scheduler.state_dict()
。这样一来,下次直接scheduler.load_state_dict()
就能完成next_epoch
的preparation。 -
先
save(scheduler.state_dict())
,再scheduler.step()
。 如果这样做的话,恢复训练时,在scheduler.load_state_dict()
后,需要手动再scheduler.step()
一次,进入到next_epoch
。
第三种就是那种无状态的有closed form的 lrscheduler,step时只依赖epoch序号的。
这种只需要记录save时的epoch,然后重新赋值给 lrscheduler ,再step()一次进入next。
样例代码:
ckp_dict = load('...somefile')
epoch = ckp_dict['epoch']
lr_scheduler.last_epoch = epoch
lr_scheduler.step() # 进入next epoch
如果是有状态的lr_scheduler,例如ReduceLROnPlateau
这种需要依赖metrics
更新的,而不仅仅依赖epoch序号的。
你就得再读save时的metrics手动恢复。
感觉挺麻烦。
最简单的办法一定是save时,先step再save。