在tensorflow 1中,有一个层tf.compat.v1.keras.layers.CuDNNLSTM是为使用cuDNN而构建的,而在tensorflow 2中,这个层被废弃为支持使用
1. `activation` == `tanh`
2. `recurrent_activation` == `sigmoid`
3. `recurrent_dropout` == 0
4. `unroll` is `False`
5. `use_bias` is `True`
6. Inputs are not masked or strictly right padded.用于cuDNN实现。我不知道是否存在bug或一些未实现的差异,但似乎与使用输入偏差和递归偏差的CuDNNLSTM不同,其中LSTM在上述tf2 cuDNN规则下只使用反复偏倚。
相关代码
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM
from tensorflow.compat.v1.keras.layers import CuDNNLSTM
print(tf.__version__)
model1 = Sequential()
model1.add(LSTM(1, activation='tanh', recurrent_dropout=0, unroll=False, use_bias=True, return_sequences=0, input_shape=(1, 1)))
print(model1.summary())
model2 = Sequential()
model2.add(CuDNNLSTM(1, return_sequences=0, input_shape=(1, 1)))
print(model2.summary())2.2.0
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm (LSTM) (None, 1) 12
=================================================================
Total params: 12
Trainable params: 12
Non-trainable params: 0
_________________________________________________________________
None
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
cu_dnnlstm (CuDNNLSTM) (None, 1) 16
=================================================================
Total params: 16
Trainable params: 16
Non-trainable params: 0
_________________________________________________________________注意,总参数因N_units * 4不同而不同,这意味着每个单元缺少一个附加的偏置向量。
请注意,LSTM的pytorch实现与tf1 CuDNNLSTM匹配,这是我偶然发现的。
有没有什么解决办法我错过了,还是应该提升到github的问题?
发布于 2020-10-15 14:46:48
不,这不是虫子。
CuDNNLSTM中的2x偏压是recurrent kernel的独立偏倚。
当CuDNNLSTM在tf.keras.layers.LSTM中可用时,您可以看到代码是以这样一种方式编写的,即它不对recurrent kernel使用单独的偏向,而是调用LSTMCell,这是一个基类,没有单独的偏向。
您可以使用model.layers[0].trainable_weights查看两个实现之间偏差的形状差异。
https://stackoverflow.com/questions/62840495
复制相似问题