From c9f6ac573b65f4aef6308d1f0921e08cecc643a7 Mon Sep 17 00:00:00 2001 From: lgulich <22480644+lgulich@users.noreply.github.com> Date: Fri, 13 Dec 2024 05:47:14 +0100 Subject: [PATCH] Fixes configclass dict conversion for torch tensors (#1530) # Description Fix configclass dict conversion for torch tensors Up to v1.2.0 if a configclass would contain a list/tuple of torch tensors it would be left as is. \#1227 changed the behavior of converting lists/tuples in a dict, which means that currently torch tensors are converted to an empty dict, effectively losing all contained data. The underlying issue is that `torch.tensor.__dict__` returns an empty dict, which was (luckily) ignored previously because we did not convert the contents of lists. This MR fixes this by treating torch tensors specially. I don't like having a special case for a non-builtin class but given that IsaacLab is heavily married with torch tensors I think it's ok in this case. Since currently the behavior is different between 1.2 and 1.3: can we cherry pick this change to the 1.3 branch? ## Type of change - Bug fix (non-breaking change which fixes an issue) ## Checklist - [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with `./isaaclab.sh --format` - [ ] I have made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works - [x] I have updated the changelog and the corresponding version in the extension's `config/extension.toml` file - [x] I have added my name to the `CONTRIBUTORS.md` or my name already exists there Co-authored-by: Kelly Guo --- .../omni.isaac.lab/omni/isaac/lab/utils/dict.py | 7 +++++++ .../omni.isaac.lab/test/utils/test_configclass.py | 15 +++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/utils/dict.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/utils/dict.py index 07086a1f9b..e695207c88 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/utils/dict.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/utils/dict.py @@ -8,6 +8,7 @@ import collections.abc import hashlib import json +import torch from collections.abc import Iterable, Mapping from typing import Any @@ -40,6 +41,11 @@ def class_to_dict(obj: object) -> dict[str, Any]: # convert object to dictionary if isinstance(obj, dict): obj_dict = obj + elif isinstance(obj, torch.Tensor): + # We have to treat torch tensors specially because `torch.tensor.__dict__` returns an empty + # dict, which would mean that a torch.tensor would be stored as an empty dict. Instead we + # want to store it directly as the tensor. + return obj elif hasattr(obj, "__dict__"): obj_dict = obj.__dict__ else: @@ -57,6 +63,7 @@ def class_to_dict(obj: object) -> dict[str, Any]: # check if attribute is a dictionary elif hasattr(value, "__dict__") or isinstance(value, dict): data[key] = class_to_dict(value) + # check if attribute is a list or tuple elif isinstance(value, (list, tuple)): data[key] = type(value)([class_to_dict(v) for v in value]) else: diff --git a/source/extensions/omni.isaac.lab/test/utils/test_configclass.py b/source/extensions/omni.isaac.lab/test/utils/test_configclass.py index 4b2f5a7ff1..bb4b3e5999 100644 --- a/source/extensions/omni.isaac.lab/test/utils/test_configclass.py +++ b/source/extensions/omni.isaac.lab/test/utils/test_configclass.py @@ -19,6 +19,7 @@ import copy import os +import torch import unittest from collections.abc import Callable from dataclasses import MISSING, asdict, field @@ -134,6 +135,14 @@ def __post_init__(self): self.add_variable = 3 +@configclass +class BasicDemoTorchCfg: + """Dummy configuration class with a torch tensor .""" + + some_number: int = 0 + some_tensor: torch.Tensor = torch.Tensor([1, 2, 3]) + + """ Dummy configuration to check type annotations ordering. """ @@ -515,6 +524,12 @@ def test_dict_conversion(self): self.assertDictEqual(cfg.to_dict(), basic_demo_cfg_correct) self.assertDictEqual(cfg.env.to_dict(), basic_demo_cfg_correct["env"]) + torch_cfg = BasicDemoTorchCfg() + torch_cfg_dict = torch_cfg.to_dict() + # We have to do a manual check because torch.Tensor does not work with assertDictEqual. + self.assertEqual(torch_cfg_dict["some_number"], 0) + self.assertTrue(torch.all(torch_cfg_dict["some_tensor"] == torch.tensor([1, 2, 3]))) + def test_dict_conversion_order(self): """Tests that order is conserved when converting to dictionary.""" true_outer_order = ["device_id", "env", "robot_default_state", "list_config"]