[Feature] Rename _TensorDict into TensorDictBase (#316)
yoavnavon authored Jul 23, 2022
1 parent f07015d commit 23ca67c
Showing 39 changed files with 548 additions and 494 deletions.
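The change itself is essentially mechanical: every reference to the private name _TensorDict becomes the public TensorDictBase, with the same class doing the same job. For downstream code, migration is a one-line import swap; a minimal before/after sketch (hypothetical user code, not part of this commit):

import torch
# before this commit:
# from torchrl.data.tensordict.tensordict import _TensorDict
# after this commit:
from torchrl.data.tensordict.tensordict import TensorDict, TensorDictBase

def describe(td: TensorDictBase) -> None:
    # TensorDictBase is the abstract parent of TensorDict and its lazy
    # variants, so it is the natural type for annotations and isinstance checks.
    print(td.batch_size, list(td.keys()))

describe(TensorDict({"obs": torch.randn(4, 3)}, batch_size=[4]))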
18 changes: 9 additions & 9 deletions test/mocking_classes.py
@@ -16,7 +16,7 @@
     UnboundedContinuousTensorSpec,
     OneHotDiscreteTensorSpec,
 )
-from torchrl.data.tensordict.tensordict import _TensorDict, TensorDict
+from torchrl.data.tensordict.tensordict import TensorDictBase, TensorDict
 from torchrl.envs.common import _EnvClass
 
 spec_dict = {
@@ -110,15 +110,15 @@ def _step(self, tensordict):
         done = torch.tensor([done], dtype=torch.bool, device=self.device)
         return TensorDict({"reward": n, "done": done, "next_observation": n}, [])
 
-    def _reset(self, tensordict: _TensorDict, **kwargs) -> _TensorDict:
+    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
         self.max_val = max(self.counter + 100, self.counter * 2)
 
         n = torch.tensor([self.counter]).to(self.device).to(torch.get_default_dtype())
         done = self.counter >= self.max_val
         done = torch.tensor([done], dtype=torch.bool, device=self.device)
         return TensorDict({"done": done, "next_observation": n}, [])
 
-    def rand_step(self, tensordict: Optional[_TensorDict] = None) -> _TensorDict:
+    def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
         return self.step(tensordict)
 
 
@@ -144,7 +144,7 @@ def _get_in_obs(self, obs):
     def _get_out_obs(self, obs):
         return obs
 
-    def _reset(self, tensordict: _TensorDict) -> _TensorDict:
+    def _reset(self, tensordict: TensorDictBase) -> TensorDictBase:
         self.counter += 1
         state = torch.zeros(self.size) + self.counter
         tensordict = tensordict.select().set(
@@ -156,8 +156,8 @@ def _reset(self, tensordict: _TensorDict) -> _TensorDict:
 
     def _step(
         self,
-        tensordict: _TensorDict,
-    ) -> _TensorDict:
+        tensordict: TensorDictBase,
+    ) -> TensorDictBase:
         tensordict = tensordict.to(self.device)
         a = tensordict.get("action")
         assert (a.sum(-1) == 1).all()
@@ -199,7 +199,7 @@ def _get_in_obs(self, obs):
     def _get_out_obs(self, obs):
         return obs
 
-    def _reset(self, tensordict: _TensorDict) -> _TensorDict:
+    def _reset(self, tensordict: TensorDictBase) -> TensorDictBase:
         self.counter += 1
         self.step_count = 0
         state = torch.zeros(self.size) + self.counter
@@ -211,8 +211,8 @@ def _reset(self, tensordict: _TensorDict) -> _TensorDict:
 
     def _step(
         self,
-        tensordict: _TensorDict,
-    ) -> _TensorDict:
+        tensordict: TensorDictBase,
+    ) -> TensorDictBase:
         self.step_count += 1
         tensordict = tensordict.to(self.device)
         a = tensordict.get("action")
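Why annotate these signatures with the base class: batched or transformed envs may hand _step and _reset a lazily stacked or sub-indexed tensordict rather than a plain TensorDict. A rough sketch of the polymorphism the new annotation captures (illustrative, using only names visible elsewhere in this diff):

import torch
from torchrl.data.tensordict.tensordict import TensorDict, TensorDictBase
from torchrl.data.tensordict.tensordict import stack as stack_td

def read_action(tensordict: TensorDictBase) -> torch.Tensor:
    # accepts any tensordict flavor, not just the concrete TensorDict class
    return tensordict.get("action")

td = TensorDict({"action": torch.ones(2)}, [])
stacked = stack_td([td, td.clone()], 0)  # a LazyStackedTensorDict
assert isinstance(stacked, TensorDictBase)
read_action(stacked)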
4 changes: 2 additions & 2 deletions test/test_distributions.py
@@ -9,7 +9,7 @@
 import torch
 from _utils_internal import get_available_devices
 from torch import nn, autograd
-from torchrl.data.tensordict.tensordict import _TensorDict
+from torchrl.data.tensordict.tensordict import TensorDictBase
 from torchrl.modules import (
     TanhNormal,
     NormalParamWrapper,
@@ -59,7 +59,7 @@ def test_delta(device, div_up, div_down):
 
 def _map_all(*tensors_or_other, device):
     for t in tensors_or_other:
-        if isinstance(t, (torch.Tensor, _TensorDict)):
+        if isinstance(t, (torch.Tensor, TensorDictBase)):
             yield t.to(device)
         else:
             yield t
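_map_all works because torch.Tensor and TensorDictBase expose the same .to(device) surface, so the helper can move a mixed bag of arguments in one pass. A short usage sketch of the helper as defined above (assumed standalone here):

import torch
from torchrl.data.tensordict.tensordict import TensorDict, TensorDictBase

def _map_all(*tensors_or_other, device):
    for t in tensors_or_other:
        if isinstance(t, (torch.Tensor, TensorDictBase)):
            yield t.to(device)
        else:
            yield t  # non-tensor arguments pass through unchanged

x, td, scale = _map_all(
    torch.randn(3), TensorDict({"loc": torch.zeros(3)}, []), 0.5, device="cpu"
)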
14 changes: 7 additions & 7 deletions test/test_rb.py
@@ -21,7 +21,7 @@
     LazyMemmapStorage,
     LazyTensorStorage,
 )
-from torchrl.data.tensordict.tensordict import assert_allclose_td, _TensorDict
+from torchrl.data.tensordict.tensordict import assert_allclose_td, TensorDictBase
 
 
 collate_fn_dict = {
@@ -128,7 +128,7 @@ def test_add(self, rbtype, storage, size, prefetch):
         data = self._get_datum(rbtype)
         rb.add(data)
         s = rb._storage[0]
-        if isinstance(s, _TensorDict):
+        if isinstance(s, TensorDictBase):
             assert (s == data.select(*s.keys())).all()
         else:
             assert (s == data).all()
@@ -142,12 +142,12 @@ def test_extend(self, rbtype, storage, size, prefetch):
         for d in data[-length:]:
             found_similar = False
             for b in rb._storage:
-                if isinstance(b, _TensorDict):
+                if isinstance(b, TensorDictBase):
                     b = b.exclude("index").select(*set(d.keys()).intersection(b.keys()))
                     d = d.select(*set(d.keys()).intersection(b.keys()))
 
                 value = b == d
-                if isinstance(value, (torch.Tensor, _TensorDict)):
+                if isinstance(value, (torch.Tensor, TensorDictBase)):
                     value = value.all()
                 if value:
                     found_similar = True
@@ -160,18 +160,18 @@ def test_sample(self, rbtype, storage, size, prefetch):
         data = self._get_data(rbtype, size=5)
         rb.extend(data)
         new_data = rb.sample(3)
-        if not isinstance(new_data, (torch.Tensor, _TensorDict)):
+        if not isinstance(new_data, (torch.Tensor, TensorDictBase)):
             new_data = new_data[0]
 
         for d in new_data:
             found_similar = False
             for b in data:
-                if isinstance(b, _TensorDict):
+                if isinstance(b, TensorDictBase):
                     b = b.exclude("index").select(*set(d.keys()).intersection(b.keys()))
                     d = d.select(*set(d.keys()).intersection(b.keys()))
 
                 value = b == d
-                if isinstance(value, (torch.Tensor, _TensorDict)):
+                if isinstance(value, (torch.Tensor, TensorDictBase)):
                     value = value.all()
                 if value:
                     found_similar = True
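These loops compare entries key by key; the import touched above also brings in assert_allclose_td, which packages that tensordict comparison as a single assertion. A small illustrative sketch (not taken from this file):

import torch
from torchrl.data.tensordict.tensordict import TensorDict, assert_allclose_td

a = TensorDict({"obs": torch.ones(3)}, [])
b = a.clone()
assert_allclose_td(a, b)  # passes: matching keys with allclose values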
4 changes: 2 additions & 2 deletions test/test_tensor_spec.py
@@ -18,7 +18,7 @@
     UnboundedContinuousTensorSpec,
     OneHotDiscreteTensorSpec,
 )
-from torchrl.data.tensordict.tensordict import TensorDict, _TensorDict
+from torchrl.data.tensordict.tensordict import TensorDict, TensorDictBase
 
 
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None])
@@ -376,7 +376,7 @@ def test_nested_composite_spec(self, is_complete, device, dtype):
         ts = self._composite_spec(is_complete, device, dtype)
         ts["nested_cp"] = self._composite_spec(is_complete, device, dtype)
         td = ts.rand()
-        assert isinstance(td["nested_cp"], _TensorDict)
+        assert isinstance(td["nested_cp"], TensorDictBase)
         keys = list(td.keys())
         for key in keys:
             if key != "nested_cp":
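The test checks that sampling from a nested composite spec yields a nested tensordict. A minimal sketch of that behavior, assuming CompositeSpec and UnboundedContinuousTensorSpec are importable from torchrl.data as the specs in the header above are:

from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.data.tensordict.tensordict import TensorDictBase

spec = CompositeSpec(obs=UnboundedContinuousTensorSpec())
spec["nested_cp"] = CompositeSpec(obs=UnboundedContinuousTensorSpec())
td = spec.rand()
assert isinstance(td["nested_cp"], TensorDictBase)  # nested spec -> nested tensordict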
6 changes: 3 additions & 3 deletions test/test_tensordict.py
@@ -18,7 +18,7 @@
     LazyStackedTensorDict,
     stack as stack_td,
     pad,
-    _TensorDict,
+    TensorDictBase,
 )
 from torchrl.data.tensordict.utils import _getitem_batch_size, convert_ellipsis_to_idx
 
@@ -833,7 +833,7 @@ def test_masking_set(self, td_name, device, from_list):
         def zeros_like(item, n, d):
             if isinstance(item, (MemmapTensor, torch.Tensor)):
                 return torch.zeros(n, *item.shape[d:], dtype=item.dtype, device=device)
-            elif isinstance(item, _TensorDict):
+            elif isinstance(item, TensorDictBase):
                 batch_size = item.batch_size
                 batch_size = [n, *batch_size[d:]]
                 out = TensorDict(
@@ -1344,7 +1344,7 @@ def test_flatten_keys(self, td_name, device, inplace, separator):
 
         td_flatten = td.flatten_keys(inplace=inplace, separator=separator)
         for key, value in td_flatten.items():
-            assert not isinstance(value, _TensorDict)
+            assert not isinstance(value, TensorDictBase)
         assert (
             separator.join(["nested_tensordict", "nested_nested_tensordict", "a"])
             in td_flatten.keys()
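flatten_keys rewrites nested entries as top-level keys joined by the separator, so no value in the result is itself a tensordict, which is what the two assertions verify. A compact sketch with hypothetical data:

import torch
from torchrl.data.tensordict.tensordict import TensorDict, TensorDictBase

td = TensorDict(
    {"a": torch.zeros(2), "nested": TensorDict({"b": torch.ones(2)}, [2])},
    [2],
)
flat = td.flatten_keys(separator=".")
assert "nested.b" in flat.keys()
for key, value in flat.items():
    assert not isinstance(value, TensorDictBase)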