Explanation of the LSTM architecture code in PyTorch
Summary
Original text
`class torch.nn.LSTM(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0.0, bidirectional=False, proj_size=0, device=None, dtype=None)`
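For context, here is a minimal usage sketch of this constructor and its forward call; the layer sizes, sequence length, and batch size below are made-up values chosen only for illustration:

```python
import torch
import torch.nn as nn

# Hypothetical sizes, for illustration only.
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, batch_first=False)

seq_len, batch = 5, 3
x = torch.randn(seq_len, batch, 10)   # shape (L, N, input_size) since batch_first=False

# h_0 and c_0 default to zeros when not passed explicitly.
output, (h_n, c_n) = lstm(x)
print(output.shape)  # torch.Size([5, 3, 20])  -> (L, N, hidden_size)
print(h_n.shape)     # torch.Size([2, 3, 20])  -> (num_layers, N, hidden_size)
print(c_n.shape)     # torch.Size([2, 3, 20])
```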
Apply a multi-layer long short-term memory (LSTM) RNN to an input sequence.
For each element in the input sequence, each layer computes the following function:
$i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi})$

$f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf})$

$g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg})$

$o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho})$

$c_t = f_t \odot c_{t-1} + i_t \odot g_t$
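To see these gate equations concretely, here is a small sketch that computes them by hand for one time step and checks the cell state against `nn.LSTMCell`. It relies on PyTorch packing the weights row-wise in the gate order $i, f, g, o$ inside `weight_ih`/`weight_hh`; the tensor sizes are made-up values for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

input_size, hidden_size = 4, 3          # hypothetical small sizes
cell = nn.LSTMCell(input_size, hidden_size)

x_t    = torch.randn(1, input_size)     # one time step, batch of 1
h_prev = torch.zeros(1, hidden_size)    # h_{t-1}
c_prev = torch.zeros(1, hidden_size)    # c_{t-1}

# weight_ih stacks W_ii | W_if | W_ig | W_io, shape (4*hidden_size, input_size);
# weight_hh stacks the corresponding W_h* matrices, and the biases follow the same order.
gates = x_t @ cell.weight_ih.T + cell.bias_ih + h_prev @ cell.weight_hh.T + cell.bias_hh
i_t, f_t, g_t, o_t = gates.chunk(4, dim=1)

i_t = torch.sigmoid(i_t)                # input gate
f_t = torch.sigmoid(f_t)                # forget gate
g_t = torch.tanh(g_t)                   # candidate cell state
o_t = torch.sigmoid(o_t)                # output gate

c_t = f_t * c_prev + i_t * g_t          # new cell state, as in the formula above

# The built-in cell returns (h_1, c_1); its cell state matches our manual c_t.
_, c_ref = cell(x_t, (h_prev, c_prev))
print(torch.allclose(c_t, c_ref))       # True
```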