
GAE does not support LSTM-based value network. #2444

Open
1 task done
levelrin opened this issue Sep 19, 2024 · 2 comments
Labels: enhancement (New feature or request)

levelrin commented Sep 19, 2024

Motivation

I got the following error when I used GAE with an LSTM-based value network:

```
RuntimeError: Batching rule not implemented for aten::lstm.input. We could not generate a fallback.
```

Here is the code I ran:

```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.objectives.value import GAE

class ValueNetwork(nn.Module):

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=2,
            hidden_size=1,
            num_layers=1,
            bidirectional=False,
            batch_first=True
        )

    def forward(self, i):
        output, (hidden_state, cell_state) = self.lstm(i)
        return hidden_state

def main():
    value_network = ValueNetwork()
    value_dict_module = TensorDictModule(value_network, in_keys=["observation"], out_keys=["value"])
    gae = GAE(
        gamma=0.98,
        lmbda=0.95,
        value_network=value_dict_module
    )
    gae.set_keys(
        advantage="advantage",
        value_target="value_target",
        value="value",
    )
    tensor_dict = TensorDict({
        "next": {
            "observation": torch.FloatTensor([
                [[8, 9], [10, 11]],
                [[12, 13], [14, 15]]
            ]),
            "reward": torch.FloatTensor([[1], [-1]]),
            "done": torch.BoolTensor([[1], [1]]),
            "terminated": torch.BoolTensor([[1], [1]])
        },
        "observation": torch.FloatTensor([
            [[0, 1], [2, 3]],
            [[4, 5], [6, 7]]
        ])
    }, batch_size=2)

    output_tensor_dict = gae(tensor_dict)
    print(f"output_tensor_dict: {output_tensor_dict}")
    advantage = output_tensor_dict["advantage"]
    print(f"advantage: {advantage}")

main()
```

The error was caused by this exact line:

```python
output_tensor_dict = gae(tensor_dict)
```
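The error message suggests the value network is being wrapped in `torch.vmap`, and `aten::lstm.input` has no vmap batching rule. A minimal pure-PyTorch sketch (not TorchRL code) shows the network itself is fine when called directly on a batch; only the vmap-ed call fails:

```python
import torch
from torch import nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=2, hidden_size=1, num_layers=1, batch_first=True)

def value(obs):
    # hidden_state has shape (num_layers, batch, hidden_size)
    _, (h, _) = lstm(obs)
    return h

out = value(torch.randn(4, 3, 2))  # plain batched call works fine
print(out.shape)                   # torch.Size([1, 4, 1])
# torch.vmap(value)(torch.randn(4, 3, 2))  # raises the aten::lstm.input error
```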

I then tried an unbatched input and found that GAE does not support that either.
For example, this is the unbatched input I tried:

```python
tensor_dict = TensorDict({
    "next": {
        "observation": torch.FloatTensor([[4, 5], [6, 7]]),
        "reward": torch.FloatTensor([1]),
        "done": torch.BoolTensor([1]),
        "terminated": torch.BoolTensor([1])
    },
    "observation": torch.FloatTensor([[0, 1], [2, 3]])
}, batch_size=[])
```

And I got this error from GAE:

```
RuntimeError: Expected input tensordict to have at least one dimensions, got tensordict.batch_size = torch.Size([])
```
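For reference, the GAE recursion itself only needs a trailing time dimension, which is why the estimator insists on at least one batch dimension. A minimal sketch of the computation with a hypothetical helper (`gae_advantages` is illustrative, not TorchRL's implementation):

```python
import torch

def gae_advantages(rewards, values, next_values, dones, gamma=0.98, lmbda=0.95):
    # all inputs are (time,) tensors; done masks stop bootstrapping across episode ends
    not_done = (~dones).float()
    deltas = rewards + gamma * next_values * not_done - values
    adv = torch.zeros_like(deltas)
    running = 0.0
    for t in reversed(range(len(deltas))):
        running = deltas[t] + gamma * lmbda * not_done[t] * running
        adv[t] = running
    return adv

rewards = torch.tensor([1., -1.])
adv = gae_advantages(rewards, torch.zeros(2), torch.zeros(2), torch.tensor([True, True]))
print(adv)  # tensor([ 1., -1.]) since the done mask zeroes out all bootstrapping
```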

Therefore, I concluded that GAE does not support an LSTM-based value network.

Solution

GAE should support an LSTM-based value network.

Alternatives

GAE should support unbatched tensor dict as an input.

Additional context

I'm using torchrl version 0.5.0.

I found ticket #2372, which might be related to this issue, but I was not sure how to make my code work.

Checklist

  • I have checked that there is no similar issue in the repo (required)
@levelrin levelrin added the enhancement New feature or request label Sep 19, 2024
thomasbbrunner (Contributor) commented:

I see that you're using the Torch LSTM in your snippet. Maybe try with TorchRL's version of it (LSTMModule)?

Also, as described in #2372, you'll need to use python_based=True in your LSTMModule.
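To illustrate why a Python-based recurrence sidesteps the missing batching rule, here is a rough sketch in pure PyTorch (a hypothetical `PyLSTMValue` class, not TorchRL's LSTMModule): unrolling the LSTM with plain tensor ops means every op vmap sees has a batching rule, unlike the fused `aten::lstm` kernel.

```python
import torch
from torch import nn

class PyLSTMValue(nn.Module):
    """LSTM value head unrolled with plain tensor ops, so vmap can batch every op."""

    def __init__(self, input_size=2, hidden_size=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.x2h = nn.Linear(input_size, 4 * hidden_size)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size)

    def forward(self, obs):
        # obs: (..., time, features); start from zero hidden and cell states
        h = obs.new_zeros(*obs.shape[:-2], self.hidden_size)
        c = torch.zeros_like(h)
        for t in range(obs.shape[-2]):
            gates = self.x2h(obs[..., t, :]) + self.h2h(h)
            i, f, g, o = gates.chunk(4, dim=-1)
            c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
            h = torch.sigmoid(o) * torch.tanh(c)
        return h  # final hidden state as the value estimate

net = PyLSTMValue()
out = torch.vmap(net)(torch.randn(4, 3, 2))  # vmap succeeds here
print(out.shape)  # (4, 1)
```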

MahmoudUwk commented Oct 21, 2024

@thomasbbrunner, I have the same problem you described in #2372. I am trying recurrent PPO and used LSTMModule with python_based=True, but I got the same error:

```
RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow.
```

It comes from this part of the rnn script:

```python
if self.recurrent_mode and is_init[..., 1:].any():
```

Basically, the slicing `is_init[..., 1:]` raises the error. The code works with GAE's shifted flag set to True, but PPO performance is then poor.
Do you have any insight into the problem?
