马春杰杰 Exit Reader Mode

[mcj]pytorch基本函数理解【持续更新】

这篇博文主要讲pytorch的基本函数理解。

1 torch.arange(min,max,stride)

类似于python中的range

print(torch.arange(1,10,3))
> tensor([1, 4, 7])

可以看到,这相当于从1~10,每隔3取一个数。

2 torch.normal(mean,std)

生成平均值为mean,标准差为std的随机数。例如

torch.normal(means=torch.arange(1, 10 ,3), std=torch.arange(1, 0, -0.3))

注意,在最新的pytorch1.x运行上面语句时,会出现错误:

TypeError      Traceback (most recent call last)
<ipython-input-74-22573061bc42> in <module>      
1 torch.cuda.is_available()      
2 ----> 
3 torch.normal(torch.arange(1., 11.,3.),torch.arange(1, 0, -0.3))TypeError: normal() received an invalid combination of arguments - got (std=Tensor, means=Tensor, ), but expected one of:
 * (Tensor mean, Tensor std, torch.Generator generator, Tensor out)
 * (Tensor mean, float std, torch.Generator generator, Tensor out)
 * (float mean, Tensor std, torch.Generator generator, Tensor out

这是因为数据类型不一致,改为:

torch.normal(torch.arange(1., 11.,3.),torch.arange(1, 0, -0.3))
> tensor([0.2560, 3.9031, 6.1879, 9.8848])

即可!

这里0.2560是从均值为1,标准差为1的正态分布中随机生成的数,同理,3.9031是从均值为4,标准差为0.7的正态分布中随机生成的。

3 torch.cat(a,b,0/1)

这是用于拼接两个张量的函数,比如:

>>> import torch
>>> A=torch.ones(2,3) #2x3的张量(矩阵)                                     
>>> A
tensor([[ 1.,  1.,  1.],
        [ 1.,  1.,  1.]])
>>> B=2*torch.ones(4,3)#4x3的张量(矩阵)                                    
>>> B
tensor([[ 2.,  2.,  2.],
        [ 2.,  2.,  2.],
        [ 2.,  2.,  2.],
        [ 2.,  2.,  2.]])
>>> C=torch.cat((A,B),0)#按维数0(行)拼接
>>> C
tensor([[ 1.,  1.,  1.],
         [ 1.,  1.,  1.],
         [ 2.,  2.,  2.],
         [ 2.,  2.,  2.],
         [ 2.,  2.,  2.],
         [ 2.,  2.,  2.]])
>>> C.size()
torch.Size([6, 3])
>>> D=2*torch.ones(2,4) #2x4的张量(矩阵)
>>> C=torch.cat((A,D),1)#按维数1(列)拼接
>>> C
tensor([[ 1.,  1.,  1.,  2.,  2.,  2.,  2.],
        [ 1.,  1.,  1.,  2.,  2.,  2.,  2.]])
>>> C.size()
torch.Size([2, 7])

这里的尾缀0或者1代表按照什么进行拼接,如果是0就按照行进行拼接,如果是1就按照列进行拼接。

4 DataLoader()

这是用于将大量数据分为小batch的。比如一共1000张图片,我可以分为10个batch,每个batch size为100.

这个函数位于:torch.utils.data

import torch
import torch.utils.data as Data
BATCH_SIZE = 8
x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)

torch_dataset = Data.TensorDataset(x,y)
loader = Data.DataLoader(
    dataset = torch_dataset,
    batch_size = BATCH_SIZE,
    shuffle=True,   # 是否打乱
    num_workers=2,   # 每次取值的线程数
)
print(loader)
for epoch in range(3):
    for step,(batch_x,batch_y) in enumerate(loader,start=1):
        print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',batch_x.numpy(), '| batch y: ', batch_y.numpy())
# 结果:
<torch.utils.data.dataloader.DataLoader object at 0x7f2db246f2b0>
Epoch:  0 | Step:  1 | batch x:  [10.  2.  1.  7.  3.  8.  6.  4.] | batch y:  [ 1.  9. 10.  4.  8.  3.  5.  7.]
Epoch:  0 | Step:  2 | batch x:  [5. 9.] | batch y:  [6. 2.]
Epoch:  1 | Step:  1 | batch x:  [ 5.  4.  3.  8.  1.  9.  2. 10.] | batch y:  [ 6.  7.  8.  3. 10.  2.  9.  1.]
Epoch:  1 | Step:  2 | batch x:  [7. 6.] | batch y:  [4. 5.]
Epoch:  2 | Step:  1 | batch x:  [ 7.  9.  3.  6. 10.  2.  4.  8.] | batch y:  [4. 2. 8. 5. 1. 9. 7. 3.]
Epoch:  2 | Step:  2 | batch x:  [1. 5.] | batch y:  [10.  6.]

这里的enumerate是python语法,类似于range,只不过它迭代的是一个可遍历的数据对象,比如列表、元组或者字符串。这里它会返回两个值,一个是索引,另一个是可遍历的对象。比如在这里,会返回索引(1,2),可遍历的对象(batch_x,batch_y),start=1意思是索引从1开始。

本文最后更新于2019年5月21日,已超过 1 年没有更新,如果文章内容或图片资源失效,请留言反馈,我们会及时处理,谢谢!