看到一篇非常好的文章,介绍pytorch保存和加载模型的,非常详细。备份:
太长不看版:
方法1:
torch.save(modelA.state_dict(), PATH)
model.load_state_dict(torch.load(PATH), strict=False)
方法2:
torch.save(model, PATH)
model = torch.load(PATH)
0 前言
这篇博客主要是对使用PyTorch保存和加载训练模型参数的一个学习记录。第1-5小节是比较常规的模型参数保存操作,第6小是用已经训练好的模型参数来初始化新的模型,包括从一层加载到另一层,某些参数名不匹配的情况,也给出了实验代码和结果
,完整实验项目见github。如果对您有所帮助,欢迎关注点赞~
1 state_dict
在PyTorch中,torch.nn.Module的可学习参数(i.e. weights and biases),保存在模型的parameters中,它可以通过model.parameters()进行访问。state_dict是一个从参数名称映射到参数Tensor的字典对象。注意,只有具有可学习参数的层(卷积层、线性层等)和已经注册的缓冲区(bachnorm’s running _mean)才有state_dict中的条目。优化器(optim)也有一个state_dict,其中包含关于优化器状态以及所使用的超参数的信息。由于state_dic对象是Python字典,因此可以轻松地保存、更新、更改和还原它们,从而为PyTorch模型和优化增加了很多模块性。
从训练分类器教程中使用的简单模型看一下state_dict。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
# Define model class TheModelClass(nn.Module): def __init__(self): super(TheModelClass, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x # Initialize model model = TheModelClass() # Initialize optimizer optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # Print model's state_dict print("Model's state_dict:") for param_tensor in model.state_dict(): print(param_tensor, "\t", model.state_dict()[param_tensor].size()) # Print optimizer's state_dict print("Optimizer's state_dict:") for var_name in optimizer.state_dict(): print(var_name, "\t", optimizer.state_dict()[var_name]) |
从输出结果可以看出,每一层的模型参数的名称格式是:层名.参数;如果有它的一层是由另一个类定义的话,那么就把层名往后扩展:层名.层名…参数。下面对上述代码的模型进行重新整理,验证一下。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
TheModelClass( (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1)) (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1)) (fc1): Linear(in_features=400, out_features=120, bias=True) (fc2): Linear(in_features=120, out_features=84, bias=True) (fc3): Linear(in_features=84, out_features=10, bias=True) ) Model's state_dict: conv1.weight torch.Size([6, 3, 5, 5]) conv1.bias torch.Size([6]) conv2.weight torch.Size([16, 6, 5, 5]) conv2.bias torch.Size([16]) fc1.weight torch.Size([120, 400]) fc1.bias torch.Size([120]) fc2.weight torch.Size([84, 120]) fc2.bias torch.Size([84]) fc3.weight torch.Size([10, 84]) fc3.bias torch.Size([10]) |
模型重新整理的代码与结果:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
class TheModelClass(nn.Module): def __init__(self): super(TheModelClass, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc_the_model_class = FC() def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = self.fc_the_model_class(x) return x class FC(nn.Module): def __init__(self): super(FC, self).__init__() self.fc = nn.Sequential( nn.Linear(16 * 5 * 5, 120), nn.ReLU(), nn.Linear(120, 84), nn.ReLU(), nn.Linear(84, 10) ) def forward(self, x): return self.fc(x) # Initialize model model = TheModelClass() print(model) # Initialize optimizer optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # Print model's state_dict print("Model's state_dict:") for param_tensor in model.state_dict(): print(param_tensor, "\t", model.state_dict()[param_tensor].size()) |
参数名称输出。整个模型的网络结构还是一样的,但是将全连接层重新使用FC类来定义了。从输出的网络结构可以看出,TheModelClass类中定义时,使用的类的结构层次,在网络的结构中会体现。与(conv2)并列的(fc_the_model_class)是在TheModelClass类定义时用的变量名。后接的FC是fc_the_model_class使用的类名,后面的是这个类中定义的层。输出模型时,就是按一种深度优先的方法遍历了整个模型。对于更深层次层的参数,类名是不会出现在参数名中的,然后将参数名按深度组织:fc_the_model_class.fc.0.weight,也就是在打印过程中,:后面的名称会忽略。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
TheModelClass( (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1)) (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1)) (fc_the_model_class): FC( (fc): Sequential( (0): Linear(in_features=400, out_features=120, bias=True) (1): ReLU() (2): Linear(in_features=120, out_features=84, bias=True) (3): ReLU() (4): Linear(in_features=84, out_features=10, bias=True) ) ) ) Model's state_dict: conv1.weight torch.Size([6, 3, 5, 5]) conv1.bias torch.Size([6]) conv2.weight torch.Size([16, 6, 5, 5]) conv2.bias torch.Size([16]) fc_the_model_class.fc.0.weight torch.Size([120, 400]) fc_the_model_class.fc.0.bias torch.Size([120]) fc_the_model_class.fc.2.weight torch.Size([84, 120]) fc_the_model_class.fc.2.bias torch.Size([84]) fc_the_model_class.fc.4.weight torch.Size([10, 84]) fc_the_model_class.fc.4.bias torch.Size([10]) |
2 保存和加载用于推理的模型参数
保存使用:
1 |
torch.save(model.state_dict(), PATH) |
加载使用:
1 2 3 |
model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.eval() |
保存模型用于推理时,仅需要保存训练后的模型的参数,使用torch.save()函数直接保存模型的state_dict,通常文件的后缀名为.pt或.pth。请记住,在运行推理之前,必须先调用model.eval(),将dropout层和batch normalization层设为关闭状态。否则将会产生不一致的推断结果。
需要注意的是,load_state_dict()函数使用的是字典对象,而不是保存对象的路径,所以需要先进行torch.load(PATH)
3 保存和加载整个模型
保存使用:
1 |
torch.save(model, PATH) |
加载使用:
1 2 3 |
# Model class must be defined somewhere model = torch.load(PATH) model.eval() |
4 保存和加载用于推理或者继续训练的general checkpoing
保存使用:
1 2 3 4 5 6 7 |
torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, ... }, PATH) |
加载使用:
1 2 3 4 5 6 7 8 9 10 11 12 |
model = TheModelClass(*args, **kwargs) optimizer = TheOptimizerClass(*args, **kwargs) checkpoint = torch.load(PATH) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] loss = checkpoint['loss'] model.eval() # - or - model.train() |
在保存checkpoint用于继续训练时,保存优化器的state_dict是必要的,因为它包含着随着模型训练而更新的缓冲区和参数,可能也需要保存一些其他的项目,包括epoch和loss。常见的PyTorch约定是使用.tar文件扩展名保存这些检查点。
5 将多个模型参数保存在一个文件中
保存使用 本质上还是保存的是一个字典对象,PyTorch约定使用.tar保存这些检查点。 :
1 2 3 4 5 6 7 |
torch.save({ 'modelA_state_dict': modelA.state_dict(), 'modelB_state_dict': modelB.state_dict(), 'optimizerA_state_dict': optimizerA.state_dict(), 'optimizerB_state_dict': optimizerB.state_dict(), ... }, PATH) |
加载使用 加载还是加载的是字典对象,然后取字典对象。 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
modelA = TheModelAClass(*args, **kwargs) modelB = TheModelBClass(*args, **kwargs) optimizerA = TheOptimizerAClass(*args, **kwargs) optimizerB = TheOptimizerBClass(*args, **kwargs) checkpoint = torch.load(PATH) modelA.load_state_dict(checkpoint['modelA_state_dict']) modelB.load_state_dict(checkpoint['modelB_state_dict']) optimizerA.load_state_dict(checkpoint['optimizerA_state_dict']) optimizerB.load_state_dict(checkpoint['optimizerB_state_dict']) modelA.eval() modelB.eval() # - or - modelA.train() modelB.train() |
6 使用来自不同模型的参数进行 Warmstarting Model ★
保存使用:
1 |
torch.save(modelA.state_dict(), PATH) |
加载使用:对于不同的模型,设置strict=False是必要的。
1 2 |
modelB = TheModelBClass(*args, **kwargs) modelB.load_state_dict(torch.load(PATH), strict=False) |
在迁移学习或者训练新的复杂模型时,部分加载模型或加载部分模型是常见的方案。利用训练过的参数,即使是只有一小部分可以使用,也会对warmstart训练过程有所帮助,而且有望比从头开始训练模型更快地收敛。所谓的warmstart,我理解的就是在参数初始化时,将待训练模型的参数使用已经训练好的模型的部分参数进行初始化,然后接着训练,这种参数初始化方案会大大提高收敛的速度。
无论是从缺少某些键的部分state_dict加载,还是要加载比待加载模型更多的键的state_dic,都可以在lod_state_dict()中将strict参数设置为False,这样可以忽略不匹配的键。
如果要将参数从一层加载到另一层,但是某些键不匹配,只需要加载的state_dict中参数键的名称,来匹配到要加载到的模型中的键。 实验代码如下。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
targetModel = TheModelClass() cifar_net = torch.load('./cifar_net.pth') for item in cifar_net: print('cifar_net \t', item, '\t') targetModel.load_state_dict(cifar_net, strict=False) for item in targetModel.state_dict(): print('targetModel \t', item, '\t') print('cifar_net \t', cifar_net["fc3.bias"], '\t', cifar_net["fc3.bias"].data) print('targetModel \t', targetModel.state_dict()["fc_the_model_class.fc.4.bias"], '\t', targetModel.state_dict()["fc_the_model_class.fc.4.bias"].data) # 更新层的名称 cifar_net["fc_the_model_class.fc.0.weight"] = cifar_net.pop("fc1.weight") cifar_net["fc_the_model_class.fc.0.bias"] = cifar_net.pop("fc1.bias") cifar_net["fc_the_model_class.fc.2.weight"] = cifar_net.pop("fc2.weight") cifar_net["fc_the_model_class.fc.2.bias"] = cifar_net.pop("fc2.bias") cifar_net["fc_the_model_class.fc.4.weight"] = cifar_net.pop("fc3.weight") cifar_net["fc_the_model_class.fc.4.bias"] = cifar_net.pop("fc3.bias") targetModel.load_state_dict(cifar_net, strict=False) print('cifar_net \t', cifar_net["fc_the_model_class.fc.4.bias"], '\t', cifar_net["fc_the_model_class.fc.4.bias"].data) print('targetModel \t', targetModel.state_dict()["fc_the_model_class.fc.4.bias"], '\t', targetModel.state_dict()["fc_the_model_class.fc.4.bias"].data) |
输出结果,可以看出fc_the_model_class.fc.4.bias的参数由随机初始化,变成从cifar_net模型中初始化。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
cifar_net tensor([-3.1628e-01, -9.6105e-01, -7.1377e-04, 4.6838e-01, 1.1072e+00, -2.2960e-01, 1.9044e-01, -5.1352e-02, 1.8365e-01, -3.4669e-01], device='cuda:0') tensor([-3.1628e-01, -9.6105e-01, -7.1377e-04, 4.6838e-01, 1.1072e+00, -2.2960e-01, 1.9044e-01, -5.1352e-02, 1.8365e-01, -3.4669e-01], device='cuda:0') targetModel tensor([-0.0878, -0.1059, -0.0949, 0.0353, 0.0164, -0.1002, -0.0126, -0.1012, -0.0115, -0.1006]) tensor([-0.0878, -0.1059, -0.0949, 0.0353, 0.0164, -0.1002, -0.0126, -0.1012, -0.0115, -0.1006]) cifar_net tensor([-3.1628e-01, -9.6105e-01, -7.1377e-04, 4.6838e-01, 1.1072e+00, -2.2960e-01, 1.9044e-01, -5.1352e-02, 1.8365e-01, -3.4669e-01], device='cuda:0') tensor([-3.1628e-01, -9.6105e-01, -7.1377e-04, 4.6838e-01, 1.1072e+00, -2.2960e-01, 1.9044e-01, -5.1352e-02, 1.8365e-01, -3.4669e-01], device='cuda:0') targetModel tensor([-3.1628e-01, -9.6105e-01, -7.1377e-04, 4.6838e-01, 1.1072e+00, -2.2960e-01, 1.9044e-01, -5.1352e-02, 1.8365e-01, -3.4669e-01]) tensor([-3.1628e-01, -9.6105e-01, -7.1377e-04, 4.6838e-01, 1.1072e+00, -2.2960e-01, 1.9044e-01, -5.1352e-02, 1.8365e-01, -3.4669e-01]) |
参考资料
- pytorch 加载使用部分预训练模型(pretrained model),这篇博客里的功能可以直接由设置strict=True来完成。
- pytorch文档:SAVING AND LOADING MODELS
- Dive-into-DL-PyTorch
- python 字典修改键(key)的方法。
最新评论
大佬 http://ffflyy22.xyz/apps/fly.html 这是我试着创的站 但苹果手机安装后 执行都卡死 想请问是哪边出问题 可以的话 我提供网页编码寄给您 跪求协助…
马大佬您好 自己弄了一个苹果站 但测试安装不能用 连夏时也没法 能否与您交流下呢 我邮箱 ghostchat1@protonmail.com 跪求大佬T.T
该评论为私密评论
该评论为私密评论
该评论为私密评论
大佬 误传了一个文件 可以帮忙删除一下吗
aaaaa 我
我的也是