这篇博文主要讲pytorch的基本函数理解。
1 torch.arange(min,max,stride)
类似于python中的range
1 2 |
print(torch.arange(1,10,3)) > tensor([1, 4, 7]) |
可以看到,这相当于从1~10,每隔3取一个数。
2 torch.normal(mean,std)
生成平均值为mean,标准差为std的随机数。例如
1 |
torch.normal(means=torch.arange(1, 10 ,3), std=torch.arange(1, 0, -0.3)) |
注意,在最新的pytorch1.x运行上面语句时,会出现错误:
1 2 3 4 5 6 7 8 |
TypeError Traceback (most recent call last) <ipython-input-74-22573061bc42> in <module> 1 torch.cuda.is_available() 2 ----> 3 torch.normal(torch.arange(1., 11.,3.),torch.arange(1, 0, -0.3))TypeError: normal() received an invalid combination of arguments - got (std=Tensor, means=Tensor, ), but expected one of: * (Tensor mean, Tensor std, torch.Generator generator, Tensor out) * (Tensor mean, float std, torch.Generator generator, Tensor out) * (float mean, Tensor std, torch.Generator generator, Tensor out |
这是因为数据类型不一致,改为:
1 2 |
torch.normal(torch.arange(1., 11.,3.),torch.arange(1, 0, -0.3)) > tensor([0.2560, 3.9031, 6.1879, 9.8848]) |
即可!
这里0.2560是从均值为1,标准差为1的正态分布中随机生成的数,同理,3.9031是从均值为4,标准差为0.7的正态分布中随机生成的。
3 torch.cat(a,b,0/1)
这是用于拼接两个张量的函数,比如:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
>>> import torch >>> A=torch.ones(2,3) #2x3的张量(矩阵) >>> A tensor([[ 1., 1., 1.], [ 1., 1., 1.]]) >>> B=2*torch.ones(4,3)#4x3的张量(矩阵) >>> B tensor([[ 2., 2., 2.], [ 2., 2., 2.], [ 2., 2., 2.], [ 2., 2., 2.]]) >>> C=torch.cat((A,B),0)#按维数0(行)拼接 >>> C tensor([[ 1., 1., 1.], [ 1., 1., 1.], [ 2., 2., 2.], [ 2., 2., 2.], [ 2., 2., 2.], [ 2., 2., 2.]]) >>> C.size() torch.Size([6, 3]) >>> D=2*torch.ones(2,4) #2x4的张量(矩阵) >>> C=torch.cat((A,D),1)#按维数1(列)拼接 >>> C tensor([[ 1., 1., 1., 2., 2., 2., 2.], [ 1., 1., 1., 2., 2., 2., 2.]]) >>> C.size() torch.Size([2, 7]) |
这里的尾缀0或者1代表按照什么进行拼接,如果是0就按照行进行拼接,如果是1就按照列进行拼接。
4 DataLoader()
这是用于将大量数据分为小batch的。比如一共1000张图片,我可以分为10个batch,每个batch size为100.
这个函数位于:torch.utils.data
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import torch import torch.utils.data as Data BATCH_SIZE = 8 x = torch.linspace(1,10,10) y = torch.linspace(10,1,10) torch_dataset = Data.TensorDataset(x,y) loader = Data.DataLoader( dataset = torch_dataset, batch_size = BATCH_SIZE, shuffle=True, # 是否打乱 num_workers=2, # 每次取值的线程数 ) print(loader) for epoch in range(3): for step,(batch_x,batch_y) in enumerate(loader,start=1): print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',batch_x.numpy(), '| batch y: ', batch_y.numpy()) # 结果: <torch.utils.data.dataloader.DataLoader object at 0x7f2db246f2b0> Epoch: 0 | Step: 1 | batch x: [10. 2. 1. 7. 3. 8. 6. 4.] | batch y: [ 1. 9. 10. 4. 8. 3. 5. 7.] Epoch: 0 | Step: 2 | batch x: [5. 9.] | batch y: [6. 2.] Epoch: 1 | Step: 1 | batch x: [ 5. 4. 3. 8. 1. 9. 2. 10.] | batch y: [ 6. 7. 8. 3. 10. 2. 9. 1.] Epoch: 1 | Step: 2 | batch x: [7. 6.] | batch y: [4. 5.] Epoch: 2 | Step: 1 | batch x: [ 7. 9. 3. 6. 10. 2. 4. 8.] | batch y: [4. 2. 8. 5. 1. 9. 7. 3.] Epoch: 2 | Step: 2 | batch x: [1. 5.] | batch y: [10. 6.] |
这里的enumerate是python语法,类似于range,只不过它迭代的是一个可遍历的数据对象,比如列表、元组或者字符串。这里它会返回两个值,一个是索引,另一个是可遍历的对象。比如在这里,会返回索引(1,2),可遍历的对象(batch_x,batch_y),start=1意思是索引从1开始。
本文最后更新于2019年5月21日,已超过 1 年没有更新,如果文章内容或图片资源失效,请留言反馈,我们会及时处理,谢谢!