[BUG] info['_weight'] device for Importance Sampling in PER #2518

Open
EladSharony opened this issue Oct 26, 2024 · 3 comments
Labels: bug (Something isn't working)

@EladSharony

Describe the bug

The device of info['_weight'] doesn't match the storage device.

To Reproduce

# From documentation
import torch
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler
from tensordict import TensorDict

rb = ReplayBuffer(storage=LazyTensorStorage(10, device=torch.device('cuda')),
                  sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0))
priority = torch.tensor([0, 1000])
data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
rb.add(data_0)
rb.add(data_1)
rb.update_priority(torch.tensor([0, 1]), priority=priority)
sample, info = rb.sample(10, return_info=True)

# Check devices
print(f"sample device: {sample.device}\n"
      f"info['_weight'] device: {info['_weight'].device}")
# Output:
# sample device: cuda:0
# info['_weight'] device: cpu

Expected behavior

Both should be on the same device, the one defined via storage(..., device), as these weights are later used to compute the loss.
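
For context, a minimal sketch of why the mismatch matters (tensor names are illustrative, not torchrl API, and a CUDA device is assumed to be available): the importance-sampling weights multiply a per-sample loss term that lives on the storage device, so a CPU weight tensor raises a device-mismatch error.

import torch

td_error = torch.rand(10, device="cuda")  # per-sample loss term, on the storage device
weight = torch.rand(10)                   # like info['_weight'] today: on cpu
loss = (weight * td_error).mean()         # RuntimeError: expected all tensors on the same device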

System info

import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
# Output:
# 2024.10.23 1.26.4 3.10.15 (main, Oct  3 2024, 07:27:34) [GCC 11.2.0] linux

Reason and Possible fixes

Specify the device argument in samplers.py (L508):

weight = torch.as_tensor(self._sum_tree[index], device=storage.device)
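
Until a fix lands, a possible user-side workaround (a sketch only; rb._storage is an internal attribute and may change) is to cast the info content right after sampling:

sample, info = rb.sample(10, return_info=True)
# Move every entry of the info dict to the storage device before it is used in the loss
info = {k: torch.as_tensor(v, device=rb._storage.device) for k, v in info.items()}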

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
EladSharony added the bug label on Oct 26, 2024
@vmoens (Contributor) commented on Oct 26, 2024

That, and we should also be able to execute this directly on device. I'll push some changes.

@vmoens (Contributor) commented on Oct 29, 2024

Just FYI, you could do this instead:

# Adapted from the documentation
from torchrl.data.replay_buffers import LazyTensorStorage, PrioritizedSampler, TensorDictReplayBuffer
from tensordict import TensorDict
import torch

rb = TensorDictReplayBuffer(storage=LazyTensorStorage(10, device=torch.device('cuda')),
                            sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0))

priority = torch.tensor([0, 1000])
data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
rb.add(data_0)
rb.add(data_1)
rb.update_priority(torch.tensor([0, 1]), priority=priority)
sample = rb.sample(10)


# Check devices
print(f"sample device: {sample.device}\n"
      f"sample['_weight'] device: {sample['_weight'].device}")

which will put your weights on cuda, since TensorDictReplayBuffer writes the _weight entry directly into the sampled tensordict, which lives on the storage device.

There are two issues with patching the PRB (prioritized replay buffer) to account for the device of the storage:

  1. The issue you're having is caused by the fact that, for the ReplayBuffer class, the device of the storage is unknown (it may be None), and the sampler is unaware of what the storage is; you could, for instance, have multiple storages. So in practice, if we want to cast the content of the info dict to the storage device, we would need to pass the storage device to the sampler and do the transfer there. Another option could be for the buffer (and not the sampler) to do the casting, if and only if the info dict is required (that would avoid useless H2D transfers when the info dict isn't asked for; see the sketch after the example below), but then we would still face issue (2) below.

  2. If we map the info from the PRB to the device of the storage, it may still be incomplete. In the following example, I patch the sample method but also append a device map as a transform in the buffer. As this example shows, our transform will rightfully ignore the info dict:

# Adapted from the documentation
import functools

from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler
from tensordict import TensorDict
import torch

device = "cuda"

# Patch the sampler so everything it returns (indices and the info dict) is moved to `device`
sample = PrioritizedSampler.sample
@functools.wraps(sample)
def new_sample(self, *args, **kwargs):
    out = sample(self, *args, **kwargs)
    out = torch.utils._pytree.tree_map(lambda x: x.to(device), out)
    return out
PrioritizedSampler.sample = new_sample

rb = ReplayBuffer(storage=LazyTensorStorage(10, device=torch.device(device)),
                  sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0))
# Transform that maps sampled content back to cpu; as noted above, it never sees the info dict
rb.append_transform(lambda x: x.to("cpu"))

priority = torch.tensor([0, 1000])
data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
rb.add(data_0)
rb.add(data_1)
rb.update_priority(torch.tensor([0, 1]), priority=priority)
sample, info = rb.sample(10, return_info=True)

# Check devices
print(f"sample device: {sample.device}\n"
      f"info['_weight'] device: {info['_weight'].device}")

So to recap:
The PRB is currently only hosted on CPU; it's the only part of the lib that relies on C++ code, and the fact that the computation is done on CPU is why you're getting the info dict on cpu. We could run the sum-tree and min-tree on CUDA, which shouldn't be too hard. In the meantime we can send the info dict content to the storage device (see #2527), but that will only be an incomplete patch if you're not using TensorDictReplayBuffer.

@EladSharony (Author) replied:

> Also, the sampler is unaware of what the storage is. You could have multiple storages for instance.

Maybe I'm missing something, but def sample(self, storage: Storage, batch_size: int) accepts the storage as an argument, so we can query storage.device, which would also cover the multiple-storages case.
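
Concretely, the one-line fix from the issue description would then live inside the sampler; a diff-style sketch (only the weight line is the actual proposal, the rest of the body is elided):

def sample(self, storage: Storage, batch_size: int):
    # ... existing index selection via the sum-tree, unchanged ...
    weight = torch.as_tensor(self._sum_tree[index], device=storage.device)
    # ... existing weight normalization and return, unchanged ...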

> If we map the info from the PRB to the device of the storage, it may still be incomplete. In the following example, I patch the sample method but also append a device map as a transform in the buffer. As this example shows, our transform will rightfully ignore the info dict.

That's a valid point. I wanted to suggest adding the info to the data itself, but preallocating the memory might not be trivial. On the other hand, I can't think of any reason (besides mapping to a device) for which one would need to transform the info dict.
