Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes configclass dict conversion for torch tensors #1530

Merged
merged 3 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import collections.abc
import hashlib
import json
import torch
from collections.abc import Iterable, Mapping
from typing import Any

Expand Down Expand Up @@ -40,6 +41,11 @@ def class_to_dict(obj: object) -> dict[str, Any]:
# convert object to dictionary
if isinstance(obj, dict):
obj_dict = obj
elif isinstance(obj, torch.Tensor):
# We have to treat torch tensors specially because `torch.tensor.__dict__` returns an empty
# dict, which would mean that a torch.tensor would be stored as an empty dict. Instead we
# want to store it directly as the tensor.
return obj
elif hasattr(obj, "__dict__"):
obj_dict = obj.__dict__
else:
Expand All @@ -57,6 +63,7 @@ def class_to_dict(obj: object) -> dict[str, Any]:
# check if attribute is a dictionary
elif hasattr(value, "__dict__") or isinstance(value, dict):
data[key] = class_to_dict(value)
# check if attribute is a list or tuple
elif isinstance(value, (list, tuple)):
data[key] = type(value)([class_to_dict(v) for v in value])
else:
Expand Down
15 changes: 15 additions & 0 deletions source/extensions/omni.isaac.lab/test/utils/test_configclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import copy
import os
import torch
import unittest
from collections.abc import Callable
from dataclasses import MISSING, asdict, field
Expand Down Expand Up @@ -134,6 +135,14 @@ def __post_init__(self):
self.add_variable = 3


@configclass
class BasicDemoTorchCfg:
"""Dummy configuration class with a torch tensor ."""

some_number: int = 0
some_tensor: torch.Tensor = torch.Tensor([1, 2, 3])


"""
Dummy configuration to check type annotations ordering.
"""
Expand Down Expand Up @@ -515,6 +524,12 @@ def test_dict_conversion(self):
self.assertDictEqual(cfg.to_dict(), basic_demo_cfg_correct)
self.assertDictEqual(cfg.env.to_dict(), basic_demo_cfg_correct["env"])

torch_cfg = BasicDemoTorchCfg()
torch_cfg_dict = torch_cfg.to_dict()
# We have to do a manual check because torch.Tensor does not work with assertDictEqual.
self.assertEqual(torch_cfg_dict["some_number"], 0)
self.assertTrue(torch.all(torch_cfg_dict["some_tensor"] == torch.tensor([1, 2, 3])))

def test_dict_conversion_order(self):
"""Tests that order is conserved when converting to dictionary."""
true_outer_order = ["device_id", "env", "robot_default_state", "list_config"]
Expand Down
Loading