欢迎光临
我们一直在努力

[mcj]pytorch基本函数理解【持续更新】

华为学生服务器每月仅需9元!

这篇博文主要讲pytorch的基本函数理解。

1 torch.arange(min,max,stride)

类似于python中的range

可以看到,这相当于从1~10,每隔3取一个数。

2 torch.normal(mean,std)

生成平均值为mean,标准差为std的随机数。例如

注意,在最新的pytorch1.x运行上面语句时,会出现错误:

这是因为数据类型不一致,改为:

即可!

这里0.2560是从均值为1,标准差为1的正态分布中随机生成的数,同理,3.9031是从均值为4,标准差为0.7的正态分布中随机生成的。

3 torch.cat(a,b,0/1)

这是用于拼接两个张量的函数,比如:

这里的尾缀0或者1代表按照什么进行拼接,如果是0就按照行进行拼接,如果是1就按照列进行拼接。

4 DataLoader()

这是用于将大量数据分为小batch的。比如一共1000张图片,我可以分为10个batch,每个batch size为100.

这个函数位于:torch.utils.data

这里的enumerate是python语法,类似于range,只不过它迭代的是一个可遍历的数据对象,比如列表、元组或者字符串。这里它会返回两个值,一个是索引,另一个是可遍历的对象。比如在这里,会返回索引(1,2),可遍历的对象(batch_x,batch_y),start=1意思是索引从1开始。

如果你对这篇文章有什么疑问或建议,欢迎下面留言提出,我看到会立刻回复!

打赏
未经允许不得转载:马春杰杰 » [mcj]pytorch基本函数理解【持续更新】
华为学生服务器每月仅需9元!

留个评论吧~ 抢沙发 评论前登陆可免验证码!

私密评论

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址(选填,便于回访^_^)
切换注册

登录

忘记密码 ?

切换登录

注册