马春杰杰博客
致力于深度学习经验分享!

[mcj]pytorch中LSTM的输入输出解释||LSTM输入输出详解

最近想了解一些关于LSTM的相关知识,在进行代码测试的时候,有个地方一直比较疑惑,关于LSTM的输入和输出问题。一直不清楚在pytorch里面该如何定义LSTM的输入和输出。首先看个pytorch官方的例子:

在这里如果我们打印output、hn、cn的shape,我们可以看到,torch的输出已经变成了定义中的20。

接着来看一下LSTM的参数都有哪些:

LSTM一共有7个参数,其中前三个是必须的,分别为:input_size, hidden_size, num_layers.

1 input_size

在这里首先对输入解释一下,nn.LSTM()的第一个参数为输入的序列维度,它对应着torch.randn()中的第三个参数10。可能有人不太明白这个这个函数是怎么回事,在这里解释一下:

torch.randn(5, 3, 10)会生成五组数据,每组数据有3行10列。如果用在视频中的话,这里的5等于每个视频抽取的帧数,如果视频分辨率为100*100,则第二个参数为10000,若视频为彩色三通道的话,第三个参数为3,即输入序列变为(5,10000,3),看一下这张图:

[mcj]pytorch中LSTM的输入输出解释||LSTM输入输出详解

上图是一个完整的LSTM流程,上面生成的五组数据就对应了五个A,即一个LSTM中有五个神经元。

再举个例子,比如现在有5个句子,每个句子由3个单词组成,每个单词用10维的向量组成,这样参数为:seq_len=3, batch=5, input_size=10.

输入LSTM中的X数据格式尺寸为(seq_len, batch, input_size),此外h0和c0尺寸如下

h0(num_layers * num_directions, batch_size, hidden_size)

c0(num_layers * num_directions, batch_size, hidden_size)

2 hidden_size

[mcj]pytorch中LSTM的输入输出解释||LSTM输入输出详解

对照上图可以看出,隐藏层数即为中间的节点数量。这个数量可以由用户自定义。

3 num_layers

这个是LSTM的层数,默认是1,如果我们设置为2的话,第一层计算得到h,然后把h作为输入,输给第二层。然后在最后输出最终的O。

4 bias

表示是否添加bias偏置,默认为true

5 batch_first

与LSTM的输入格式有关。

输入输出的第一维是否为 batch_size,默认值 False。因为 Torch 中,人们习惯使用Torch中带有的dataset,dataloader向神经网络模型连续输入数据,这里面就有一个 batch_size 的参数,表示一次输入多少个数据。 在 LSTM 模型中,输入数据必须是一批数据,为了区分LSTM中的批量数据和dataloader中的批量数据是否相同意义,LSTM 模型就通过这个参数的设定来区分。 如果是相同意义的,就设置为True,如果不同意义的,设置为False。 torch.LSTM 中 batch_size 维度默认是放在第二维度,故此参数设置可以将 batch_size 放在第一维度。如:input 默认是(4,1,5),中间的 1 是 batch_size,指定batch_first=True后就是(1,4,5)。所以,如果你的输入数据是二维数据的话,就应该将 batch_first 设置为True;

6 dropout

是否进行dropout操作,默认为0,输入值范围为0~1的小数,表示每次丢弃的百分比。一般用来防止过拟合。

7 bidirectional

是否进行双向RNN,默认为false。

[mcj]pytorch中LSTM的输入输出解释||LSTM输入输出详解

运行模型:

运行模型的格式是这样写的。output, (hn, cn) = model(input, (h0, c0))

从形式上看,输入结构和输出结构是一样的。都是3个输入,3个输出。

参数1:你输入的数据团。好像必须是 3 维数据。但必须注意 batch_size 的位置。是第一维,还是第二维。默认是在第二维度。是不可变的维度。最后一个维度是行数据的个数。剩下的1维数据是可变的,这就是长短数据。默认放在第一维。

参数2:隐藏层数据,也必须是3维的,第一维:是LSTM的层数,第二维:是隐藏层的batch_size数,必须和输入数据的batch_size一致。第三维:是隐藏层节点数,必须和模型实例时的参数一致。

参数3:传递层数据,也必须是3维的,通常和参数2的设置一样。它的作用是LSTM内部循环中的记忆体,用来结合新的输入一起计算。

本文最后更新于2021年6月2日,已超过 1 年没有更新,如果文章内容或图片资源失效,请留言反馈,我们会及时处理,谢谢!

如果你对这篇文章有什么疑问或建议,欢迎下面留言提出,我看到会立刻回复!

打赏
未经允许不得转载:马春杰杰 » [mcj]pytorch中LSTM的输入输出解释||LSTM输入输出详解
超级便宜的原生ChatGPT4.0

留个评论吧~ 2 评论前登陆可免验证码!

私密评论
  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址(选填,便于回访^_^)
  1. #1

    图片失效了,麻烦更新一下哦

    江述南柯 4年前 (2021-05-29) 来自天朝的朋友 谷歌浏览器 Windows 10 回复

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

微信扫一扫打赏

登录

忘记密码 ?

切换登录

注册