马春杰杰博客
致力于深度学习经验分享!

PyTorch保存和加载模型方法汇总

文章目录
[隐藏]

看到一篇非常好的文章,介绍pytorch保存和加载模型的,非常详细。备份:

太长不看版:

方法1:

torch.save(modelA.state_dict(), PATH)

model.load_state_dict(torch.load(PATH), strict=False)

方法2:

torch.save(model, PATH)

model = torch.load(PATH)

0 前言

这篇博客主要是对使用PyTorch保存和加载训练模型参数的一个学习记录。第1-5小节是比较常规的模型参数保存操作,第6小是用已经训练好的模型参数来初始化新的模型,包括从一层加载到另一层,某些参数名不匹配的情况,也给出了实验代码和结果,完整实验项目见github。如果对您有所帮助,欢迎关注点赞~

1 state_dict

在PyTorch中,torch.nn.Module的可学习参数(i.e. weights and biases),保存在模型的parameters中,它可以通过model.parameters()进行访问。state_dict是一个从参数名称映射到参数Tensor的字典对象。注意,只有具有可学习参数的层(卷积层、线性层等)和已经注册的缓冲区(bachnorm’s running _mean)才有state_dict中的条目。优化器(optim)也有一个state_dict,其中包含关于优化器状态以及所使用的超参数的信息。由于state_dic对象是Python字典,因此可以轻松地保存、更新、更改和还原它们,从而为PyTorch模型和优化增加了很多模块性。
从训练分类器教程中使用的简单模型看一下state_dict。

从输出结果可以看出,每一层的模型参数的名称格式是:层名.参数;如果有它的一层是由另一个类定义的话,那么就把层名往后扩展:层名.层名…参数。下面对上述代码的模型进行重新整理,验证一下。

模型重新整理的代码与结果:

参数名称输出。整个模型的网络结构还是一样的,但是将全连接层重新使用FC类来定义了。从输出的网络结构可以看出,TheModelClass类中定义时,使用的类的结构层次,在网络的结构中会体现。与(conv2)并列的(fc_the_model_class)是在TheModelClass类定义时用的变量名。后接的FC是fc_the_model_class使用的类名,后面的是这个类中定义的层。输出模型时,就是按一种深度优先的方法遍历了整个模型。对于更深层次层的参数,类名是不会出现在参数名中的,然后将参数名按深度组织:fc_the_model_class.fc.0.weight,也就是在打印过程中,:后面的名称会忽略。

2 保存和加载用于推理的模型参数

保存使用:

加载使用:

保存模型用于推理时,仅需要保存训练后的模型的参数,使用torch.save()函数直接保存模型的state_dict,通常文件的后缀名为.pt或.pth。请记住,在运行推理之前,必须先调用model.eval(),将dropout层和batch normalization层设为关闭状态。否则将会产生不一致的推断结果。
需要注意的是,load_state_dict()函数使用的是字典对象,而不是保存对象的路径,所以需要先进行torch.load(PATH)

3 保存和加载整个模型

保存使用:

加载使用:

4 保存和加载用于推理或者继续训练的general checkpoing

保存使用:

加载使用:

在保存checkpoint用于继续训练时,保存优化器的state_dict是必要的,因为它包含着随着模型训练而更新的缓冲区和参数,可能也需要保存一些其他的项目,包括epoch和loss。常见的PyTorch约定是使用.tar文件扩展名保存这些检查点。

5 将多个模型参数保存在一个文件中

保存使用 本质上还是保存的是一个字典对象,PyTorch约定使用.tar保存这些检查点。 

加载使用 加载还是加载的是字典对象,然后取字典对象。 :

6 使用来自不同模型的参数进行 Warmstarting Model ★ 

保存使用:

加载使用:对于不同的模型,设置strict=False是必要的。

在迁移学习或者训练新的复杂模型时,部分加载模型或加载部分模型是常见的方案。利用训练过的参数,即使是只有一小部分可以使用,也会对warmstart训练过程有所帮助,而且有望比从头开始训练模型更快地收敛。所谓的warmstart,我理解的就是在参数初始化时,将待训练模型的参数使用已经训练好的模型的部分参数进行初始化,然后接着训练,这种参数初始化方案会大大提高收敛的速度。
无论是从缺少某些键的部分state_dict加载,还是要加载比待加载模型更多的键的state_dic,都可以在lod_state_dict()中将strict参数设置为False,这样可以忽略不匹配的键。
如果要将参数从一层加载到另一层,但是某些键不匹配,只需要加载的state_dict中参数键的名称,来匹配到要加载到的模型中的键。 实验代码如下。

输出结果,可以看出fc_the_model_class.fc.4.bias的参数由随机初始化,变成从cifar_net模型中初始化。

参考资料

如果你对这篇文章有什么疑问或建议,欢迎下面留言提出,我看到会立刻回复!

打赏
未经允许不得转载:马春杰杰 » PyTorch保存和加载模型方法汇总
蝴蝶PT招人啦

留个评论吧~ 抢沙发 评论前登陆可免验证码!

私密评论
  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址(选填,便于回访^_^)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

微信扫一扫打赏

切换注册

登录

忘记密码 ?

切换登录

注册