PyTorch学习笔记07——模型的保存和加载

序列化与反序列化

模型的保存与加载也称序列化与反序列化
模型在内存中是以对象的形式存储的,而在硬盘中是以二进制序列保存的
序列化:是指将内存当中的某一个对象以二进制序列的形式存储到硬盘中,就可以长久的存储。
反序列化:将硬盘中的二进制数反序列化的放到内存中,得到对象,这样就可以使用模型了。

对应pytorch中的函数:

  1. torch.save
    主要参数:
  • obj:对象(模型、张量、parameters、dict 等等)
  • f:输出路径(指定一个硬盘中的路径去保存)

模型保存有两种方法:
法1:保存整个Module

torch.save(net, path)

法2:保存模型参数

state_dict = net.state_dict()
torch.save(state_dict, path)

比如:

net = LeNet2(classes=2019)

# "训练"
print("训练前: ", net.features[0].weight[0, ...])
net.initialize()
print("训练后: ", net.features[0].weight[0, ...])

path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl"

# 保存整个模型
torch.save(net, path_model)

# 保存模型参数
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)
  1. torch.load
    主要参数:
  • f:文件路径(对应save中的f)
  • map_location:指定存放位置,cpu or gpu(主要针对用gpu的时候)
torch.load(path)

比如:

# ================================== load net ===========================
# flag = 1
flag = 0
if flag:

    path_model = "./model.pkl"
    net_load = torch.load(path_model)

    print(net_load)

# ================================== load state_dict ===========================

flag = 1
# flag = 0
if flag:

    path_state_dict = "./model_state_dict.pkl"
    state_dict_load = torch.load(path_state_dict)

    print(state_dict_load.keys())

# ================================== update state_dict ===========================
flag = 1
# flag = 0
if flag:

    net_new = LeNet2(classes=2019)

    print("加载前: ", net_new.features[0].weight[0, ...])
    net_new.load_state_dict(state_dict_load)
    print("加载后: ", net_new.features[0].weight[0, ...])

断点续训练

模型需要保存的最基本的数据:

checkpoint = {
			"model_state_dict": net.state_dict(),
			"optimizer_state_dict": optimizer.state_dict(),
			"epoch": epoch
}

断点保存:

checkpoint_interval = 5
for epoch in range(start_epoch+1, MAX_EPOCH):
	...
	# 训练代码
	...
    if (epoch+1) % checkpoint_interval == 0:

        checkpoint = {"model_state_dict": net.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "epoch": epoch}
        path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)
    ...
    # 验证代码
    ...

续训练:
前面的数据、模型、损失函数、优化器不变,在训练之前添加断点恢复

# ============================ step 5+/5 断点恢复 ============================
# 加载checkpoint
path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)
# 更新模型数据
net.load_state_dict(checkpoint['model_state_dict'])
# 更新优化器
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# 设置起始epoch
start_epoch = checkpoint['epoch']
# 学习率也需要更改last_epoch(上一次迭代次数)
scheduler.last_epoch = start_epoch

# ============================ step 5/5 训练 ============================
...

举个栗子理解

import torch
import torchvision.models as models

PyTorch 模型将学习到的参数存储在一个内部状态字典,称为state_dict。可以通过torch.save方法持久保存这些参数:

model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')

要加载模型权重,需要先创建同一模型的实例,然后使用load_state_dict()方法加载参数。

model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

注意:请确保inferencing前调用model.eval()方法,以将 dropout 和 batch normalization 层设置为评估模式。否则将产生不一致的推理结果。

加载模型权重时,需要先实例化模型类,因为该类定义网络的结构。我们想保存这个类的结构连同模型,在这种情况下,我们可以传递model(而不是model.state_dict())给保存的函数:

torch.save(model, 'model.pth')

加载模型:

model = torch.load('model.pth')

注意:此方法在序列化模型时使用 Python pickle 模块,因此加载模型时它依赖于可用的实际类定义。

更细致的内容参见官方文档Saving and Loading a General Checkpoint in PyTorch