Skip to content

Commit

Permalink
Simplify cuDNN logic
Browse files Browse the repository at this point in the history
  • Loading branch information
felker committed Sep 12, 2022
1 parent e35f122 commit 1ba2f7c
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions plasma/models/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
Convolution1D, MaxPooling1D, TimeDistributed,
Concatenate
)
CuDNNLSTM = LSTM
# from tensorflow.compat.v1.keras.layers import CuDNNLSTM
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.regularizers import l2 # l1, l1_l2
Expand Down Expand Up @@ -137,10 +136,8 @@ def build_model(self, predict, custom_batch_size=None):
if custom_batch_size is not None:
batch_size = custom_batch_size

if rnn_type == 'LSTM':
if rnn_type == 'LSTM' or rnn_type == 'CuDNNLSTM' or rnn_type == 'cuDNNLSTM':
rnn_model = LSTM
elif rnn_type == 'CuDNNLSTM':
rnn_model = CuDNNLSTM
elif rnn_type == 'SimpleRNN':
rnn_model = SimpleRNN
else:
Expand Down

0 comments on commit 1ba2f7c

Please sign in to comment.