引入ResNet模型来训练自己的数据集
我们在做科研时,常常需要做实验。为此,参考了网上很多的教程,综合各自的写下以下内容。一方面为自己留点笔记,日后好学习。另一方面为各位朋友提供一定的参考。
以卷积神经网络的resnet网络模型为例,简要的说明如何引入模型或者修改模型来做实验。
1、引入模型
model_ft = models.resnet50(pretrained=True)
model_ft = model_ft.to(device)
这里只是引入模型,没有做任何修改。原模型中类别数为1000,因此我们在训练自己的数据集时,需要修改种类数。
2、修改自己数据集的类别
model_ft = models.resnet50(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, n) # 这里的n为类别数
model_ft = model_ft.to(device)
3、如果修改resnet的网络结构,比如加入注意力机制
model_ft = models.resnet50(pretrained=False)
net_dict = model_ft.state_dict()
predict_model = torch.load('resnet50-5c106cde.pth')
# 寻找网络中公共层,并保留预训练参数
state_dict = {k: v for k, v in predict_model.items() if k in net_dict.keys()}
net_dict.update(state_dict) # 将预训练参数更新到新的网络层
model_ft.load_state_dict(net_dict)
# 修改最后一层全连接层的数量,改为分类种类的数量
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, n)
model_ft = model_ft.to(device)
4、再附上resnet其它的模型结构
import torch
import torch.nn as nn
#__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
# 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
# 'wide_resnet50_2', 'wide_resnet101_2']
#
#
#model_urls = {
# 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
# 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
# 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
# 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
# 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
# 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
# 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
# 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
# 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
#}