@thomasbbrunner, I have the same problem you had in #2372. I am trying recurrent PPO using LSTMModule with python_based=True, and I got the same error:
"RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow."
It comes from this part of the RNN script:
if self.recurrent_mode and is_init[..., 1:].any():
Basically, slicing is_init as is_init[..., 1:] triggers the error. The code runs with the shifted flag in GAE set to True, but PPO then performs poorly.
Do you have any insight into the problem?
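For reference, the error quoted above can be reproduced outside TorchRL with a few lines of plain PyTorch. This is only a minimal sketch, not the LSTMModule code: the functions `branchy` and `branch_free` and the toy `is_init` tensor are made up for illustration. It shows that a Python `if` on a tensor fails under `torch.func.vmap`, while the same decision written with `torch.where` does not.

```python
import torch
from torch.func import vmap


def branchy(is_init: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the check in the RNN script.
    # A Python `if` needs a concrete bool, so it calls bool() on the tensor.
    if is_init[..., 1:].any():
        return torch.ones(())
    return torch.zeros(())


def branch_free(is_init: torch.Tensor) -> torch.Tensor:
    # Same decision expressed as a tensor op, with no Python-level branch.
    return torch.where(is_init[..., 1:].any(), torch.ones(()), torch.zeros(()))


# Toy flags with shape [batch, time]; only one sequence resets mid-trajectory.
is_init = torch.zeros(4, 5, dtype=torch.bool)
is_init[2, 3] = True

# Without vmap, the branch works on each sequence individually.
eager = torch.stack([branchy(row) for row in is_init])

# Under vmap, the same function raises the RuntimeError from the report.
message = ""
try:
    vmap(branchy)(is_init)
except RuntimeError as err:
    message = str(err)

# The branch-free variant runs under vmap and matches the eager result.
batched = vmap(branch_free)(is_init)
```

Whether a rewrite along these lines is appropriate inside the library is an open question; the sketch only isolates why the slice followed by `.any()` in an `if` statement cannot be traced by vmap.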
Motivation
I got the following error when I used GAE with an LSTM-based value network:
Here is the code I ran:
The error was caused by this exact line:
I tried using unbatched input and realized that GAE does not support unbatched input. For example, this is the unbatched input I tried:
And I got this error from GAE:
Therefore, I concluded that GAE does not support an LSTM-based value network.
Solution
GAE should support an LSTM-based value network.
Alternatives
GAE should support an unbatched TensorDict as input.
Additional context
I'm using torchrl version 0.5.0.
I found ticket #2372, which might be related to this issue, but I was not sure how to make my code work.
Checklist