马春杰杰博客
致力于深度学习经验分享!

[mcj]pytorch基本函数理解【持续更新】

这篇博文主要讲pytorch的基本函数理解。

1 torch.arange(min,max,stride)

类似于python中的range

可以看到,这相当于从1~10,每隔3取一个数。

2 torch.normal(mean,std)

生成平均值为mean,标准差为std的随机数。例如

注意,在最新的pytorch1.x运行上面语句时,会出现错误:

这是因为数据类型不一致,改为:

即可!

这里0.2560是从均值为1,标准差为1的正态分布中随机生成的数,同理,3.9031是从均值为4,标准差为0.7的正态分布中随机生成的。

3 torch.cat(a,b,0/1)

这是用于拼接两个张量的函数,比如:

这里的尾缀0或者1代表按照什么进行拼接,如果是0就按照行进行拼接,如果是1就按照列进行拼接。

4 DataLoader()

这是用于将大量数据分为小batch的。比如一共1000张图片,我可以分为10个batch,每个batch size为100.

这个函数位于:torch.utils.data

这里的enumerate是python语法,类似于range,只不过它迭代的是一个可遍历的数据对象,比如列表、元组或者字符串。这里它会返回两个值,一个是索引,另一个是可遍历的对象。比如在这里,会返回索引(1,2),可遍历的对象(batch_x,batch_y),start=1意思是索引从1开始。

赞(287) 打赏
版权声明:本文采用知识共享 署名4.0国际许可协议 [BY-NC-SA] 进行授权
文章名称:《[mcj]pytorch基本函数理解【持续更新】》
文章链接:https://www.machunjie.com/deeplearning/pytorch/109.html
本站资源仅供个人学习交流,请于下载后24小时内删除,不允许用于商业用途,否则法律问题自行承担。

评论 抢沙发

觉得文章有用就打赏一下文章作者

非常感谢你的打赏,我们将继续提供更多优质内容,让我们一起创建更加美好的网络世界!

支付宝扫一扫

微信扫一扫

:smile: :sad: :arrow: :cool: :confused: :cry: :eek: :evil: :exclaim: :idea: :lol: :mad: :mrgreen: :neutral: :question: :razz: :redface: :rolleyes: :surprised: :wink: :biggrin: :twisted: