PyTorch学习笔记07——模型的保存和加载
序列化与反序列化
模型的保存与加载也称序列化与反序列化
模型在内存中是以对象的形式存储的,而在硬盘中是以二进制序列保存的
序列化:是指将内存当中的某一个对象以二进制序列的形式存储到硬盘中,就可以长久的存储。
反序列化:将硬盘中的二进制数反序列化的放到内存中,得到对象,这样就可以使用模型了。
对应pytorch中的函数:
- torch.save
主要参数:
- obj:对象(模型、张量、parameters、dict 等等)
- f:输出路径(指定一个硬盘中的路径去保存)
模型保存有两种方法:
法1:保存整个Module
torch.save(net, path)
法2:保存模型参数
state_dict = net.state_dict()
torch.save(state_dict, path)
比如:
net = LeNet2(classes=2019)
# "训练"
print("训练前: ", net.features[0].weight[0, ...])
net.initialize()
print("训练后: ", net.features[0].weight[0, ...])
path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl"
# 保存整个模型
torch.save(net, path_model)
# 保存模型参数
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)
- torch.load
主要参数:
- f:文件路径(对应save中的f)
- map_location:指定存放位置,cpu or gpu(主要针对用gpu的时候)
torch.load(path)
比如:
# ================================== load net ===========================
# flag = 1
flag = 0
if flag:
path_model = "./model.pkl"
net_load = torch.load(path_model)
print(net_load)
# ================================== load state_dict ===========================
flag = 1
# flag = 0
if flag:
path_state_dict = "./model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict)
print(state_dict_load.keys())
# ================================== update state_dict ===========================
flag = 1
# flag = 0
if flag:
net_new = LeNet2(classes=2019)
print("加载前: ", net_new.features[0].weight[0, ...])
net_new.load_state_dict(state_dict_load)
print("加载后: ", net_new.features[0].weight[0, ...])
断点续训练
模型需要保存的最基本的数据:
checkpoint = {
"model_state_dict": net.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"epoch": epoch
}
断点保存:
checkpoint_interval = 5
for epoch in range(start_epoch+1, MAX_EPOCH):
...
# 训练代码
...
if (epoch+1) % checkpoint_interval == 0:
checkpoint = {"model_state_dict": net.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"epoch": epoch}
path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
torch.save(checkpoint, path_checkpoint)
...
# 验证代码
...
续训练:
前面的数据、模型、损失函数、优化器不变,在训练之前添加断点恢复
# ============================ step 5+/5 断点恢复 ============================
# 加载checkpoint
path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)
# 更新模型数据
net.load_state_dict(checkpoint['model_state_dict'])
# 更新优化器
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# 设置起始epoch
start_epoch = checkpoint['epoch']
# 学习率也需要更改last_epoch(上一次迭代次数)
scheduler.last_epoch = start_epoch
# ============================ step 5/5 训练 ============================
...
举个栗子理解
import torch
import torchvision.models as models
PyTorch 模型将学习到的参数存储在一个内部状态字典,称为state_dict。可以通过torch.save方法持久保存这些参数:
model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')
要加载模型权重,需要先创建同一模型的实例,然后使用load_state_dict()方法加载参数。
model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
注意:请确保inferencing前调用model.eval()方法,以将 dropout 和 batch normalization 层设置为评估模式。否则将产生不一致的推理结果。
加载模型权重时,需要先实例化模型类,因为该类定义网络的结构。我们想保存这个类的结构连同模型,在这种情况下,我们可以传递model(而不是model.state_dict())给保存的函数:
torch.save(model, 'model.pth')
加载模型:
model = torch.load('model.pth')
注意:此方法在序列化模型时使用 Python pickle 模块,因此加载模型时它依赖于可用的实际类定义。
更细致的内容参见官方文档Saving and Loading a General Checkpoint in PyTorch。